From 3131b0c8390e0de6faf240495f8a89c7e0f93471 Mon Sep 17 00:00:00 2001
From: Xargana
Date: Mon, 14 Jul 2025 13:44:23 +0300
Subject: [PATCH] why is there optimization in my racism app (d*scord)

---
 DATABASE_OPTIMIZATION.md | 139 ++++++++++
 cli.py                   |  32 ++-
 src/client.py            |  19 +-
 src/database.py          | 585 +++++++++++++++++++++++++++------------
 src/rate_limiter.py      |  66 +++++
 5 files changed, 640 insertions(+), 201 deletions(-)
 create mode 100644 DATABASE_OPTIMIZATION.md

diff --git a/DATABASE_OPTIMIZATION.md b/DATABASE_OPTIMIZATION.md
new file mode 100644
index 0000000..ce7025e
--- /dev/null
+++ b/DATABASE_OPTIMIZATION.md
@@ -0,0 +1,139 @@
+# Database Optimization for 50k+ Users
+
+## 1. Database Schema Improvements
+
+### Add Indexes
+```sql
+-- Primary performance indexes
+CREATE INDEX idx_users_username ON users(username);
+CREATE INDEX idx_users_display_name ON users(display_name);
+CREATE INDEX idx_users_updated_at ON users(updated_at);
+CREATE INDEX idx_user_servers_server_id ON user_servers(server_id);
+
+-- Composite indexes for common queries
+CREATE INDEX idx_users_username_display ON users(username, display_name);
+```
+
+### Query Optimization
+```sql
+-- Instead of GROUP_CONCAT, use separate queries
+SELECT * FROM users WHERE user_id = ?;
+SELECT server_id FROM user_servers WHERE user_id = ?;
+```
+
+## 2. Connection Pool Implementation
+
+Replace the single shared connection with a proper connection pool (this patch uses `asyncmy`, so the example uses its `create_pool`):
+
+```python
+import asyncmy
+
+async def create_pool():
+    return await asyncmy.create_pool(
+        host='localhost',
+        port=3306,
+        user='user',
+        password='password',
+        db='database',
+        minsize=5,
+        maxsize=20,
+        charset='utf8mb4'
+    )
+```
+
+## 3. Pagination Implementation
+
+### Database Layer
+```python
+async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
+    async with self.pool.acquire() as conn:
+        async with conn.cursor() as cursor:
+            await cursor.execute("""
+                SELECT * FROM users
+                ORDER BY user_id
+                LIMIT %s OFFSET %s
+            """, (limit, offset))
+            rows = await cursor.fetchall()
+            return [self._parse_user_result(row) for row in rows]
+```
+
+### CLI Layer
+```python
+async def search_user_paginated(query: str, page: int = 1, per_page: int = 50):
+    offset = (page - 1) * per_page
+    users = await database.search_users(query, offset, per_page)
+    # Display with pagination controls
+```
+
+## 4. Search Optimization
+
+### Full-Text Search
+```sql
+-- Add full-text index for better search
+ALTER TABLE users ADD FULLTEXT(username, display_name, bio);
+
+-- Use full-text search instead of LIKE
+-- (the MATCH() column list must exactly match the FULLTEXT index)
+SELECT * FROM users
+WHERE MATCH(username, display_name, bio) AGAINST(? IN BOOLEAN MODE);
+```
+
+### Cached Search Results
+```python
+# Cache frequent searches. Note: functools.lru_cache cannot be used on an
+# async function (it would cache the coroutine object, not the result),
+# so a small manual cache is used instead.
+_search_cache: dict = {}
+
+async def cached_user_search(query: str):
+    if query not in _search_cache:
+        _search_cache[query] = await database.search_users(query)
+    return _search_cache[query]
+```
+
+## 5. Batch Operations
+
+### Bulk Inserts
+```python
+async def save_users_batch(self, users: List[UserData]):
+    async with self.pool.acquire() as conn:
+        async with conn.cursor() as cursor:
+            # Use executemany for bulk operations
+            await cursor.executemany("""
+                INSERT INTO users (...) VALUES (...)
+                ON DUPLICATE KEY UPDATE ...
+            """, [(user.user_id, user.username, ...) for user in users])
+```
+
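+### Calling the Batch Save
+A minimal sketch of how a caller might feed users into the batch method above; chunking keeps each
+transaction small. `CHUNK_SIZE`, `save_all_users`, and `collected_users` are illustrative names, not part of this codebase.
+```python
+CHUNK_SIZE = 500  # assumed batch size; tune for your workload
+
+async def save_all_users(database, collected_users):
+    # Split the full list into fixed-size chunks so each
+    # save_users_batch() call runs as one bounded transaction.
+    for start in range(0, len(collected_users), CHUNK_SIZE):
+        chunk = collected_users[start:start + CHUNK_SIZE]
+        await database.save_users_batch(chunk)
+```
+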
+## 6. Rate Limiting Improvements
+
+### Smarter Rate Limiting
+```python
+class AdaptiveRateLimiter:
+    def __init__(self):
+        self.base_delay = 1.0
+        self.consecutive_429s = 0
+
+    async def wait(self):
+        if self.consecutive_429s > 0:
+            delay = self.base_delay * (2 ** self.consecutive_429s)
+            await asyncio.sleep(min(delay, 60))  # Cap at 60 seconds
+        else:
+            await asyncio.sleep(self.base_delay)
+```
+
+## Performance Estimates (50k users):
+
+### Current Implementation:
+- **get_all_users()**: ~30-60 seconds + 2-4 GB RAM
+- **CLI search**: ~10-30 seconds per search
+- **Database saves**: ~5-10x slower due to locking
+
+### Optimized Implementation:
+- **Paginated queries**: ~0.1-0.5 seconds per page
+- **Indexed search**: ~0.1-1 second per search
+- **Connection pool**: ~2-3x faster concurrent operations
+- **Memory usage**: ~50-100 MB instead of multiple GB
+
+## Implementation Priority:
+
+1. **High Priority**: Connection pooling, pagination, indexes
+2. **Medium Priority**: Search optimization, batch operations
+3. **Low Priority**: Caching, adaptive rate limiting
+
+The optimized version should handle 50k users with reasonable performance: roughly 10x slower than the 500-user case, instead of the current ~100x.
diff --git a/cli.py b/cli.py
index 349a1a2..8701985 100644
--- a/cli.py
+++ b/cli.py
@@ -52,27 +52,22 @@ async def show_stats():
         print(f"  Server {server_id}: {user_count} users")
 
 
-async def search_user(query: str):
-    """Search for users."""
+async def search_user(query: str, page: int = 1, per_page: int = 10):
+    """Search for users with pagination."""
     config = Config()
     database = await create_database(mariadb_config=config.get_mariadb_config())
 
-    all_users = await database.get_all_users()
-
-    # Search by username or user ID
-    results = []
-    for user in all_users:
-        if (query.lower() in user.username.lower() or
-            query.lower() in (user.display_name or "").lower() or
-            query == str(user.user_id)):
-            results.append(user)
+    offset = (page - 1) * per_page
+    results = await database.search_users(query, offset, per_page)
 
     if not results:
         print("No users found matching the query.")
         return
 
-    print(f"\n=== Found {len(results)} users ===")
-    for user in results[:10]:  # Show first 10 results
+    total_count = await database.get_user_count_total()
+    print(f"\n=== Search Results (Page {page}) ===")
+    for user in results:
         print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
         if user.display_name:
             print(f"  Display name: {user.display_name}")
@@ -98,6 +93,13 @@ async def search_user(query: str):
             print(f"  Last updated: {user.updated_at}")
         print()
 
+
+    # Pagination navigation info: the database only exposes a total user
+    # count, so base the next-page hint on whether the current page is full
+    print(f"\nShowing {len(results)} result(s) on this page ({total_count} users in database)")
+    if len(results) == per_page:
+        print(f"Use --page {page + 1} for the next page")
+
     await database.close()
 
 
 async def list_servers():
@@ -292,6 +294,8 @@ def main():
     # Search command
     search_parser = subparsers.add_parser("search", help="Search for users")
     search_parser.add_argument("query", help="Search query (username or user ID)")
+    search_parser.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
+    search_parser.add_argument("--per-page", type=int, default=10, help="Results per page (default: 10)")
 
     # Server commands
     servers_parser = subparsers.add_parser("servers", help="List all servers with user counts")
@@ -323,7 +327,7 @@ def main():
     elif args.command == "stats":
         asyncio.run(show_stats())
     elif args.command == "search":
-        
asyncio.run(search_user(args.query)) + asyncio.run(search_user(args.query, args.page, getattr(args, 'per_page', 10))) elif args.command == "servers": asyncio.run(list_servers()) elif args.command == "user-servers": diff --git a/src/client.py b/src/client.py index 5b672fc..911dbf6 100644 --- a/src/client.py +++ b/src/client.py @@ -15,7 +15,7 @@ except ImportError: from .config import Config from .database import UserData -from .rate_limiter import RateLimiter +from .rate_limiter import RateLimiter, AdaptiveRateLimiter class DiscordDataClient(discord.Client): @@ -26,9 +26,9 @@ class DiscordDataClient(discord.Client): self.config = config self.database = database - self.rate_limiter = RateLimiter( - requests_per_minute=config.max_requests_per_minute, - delay_between_requests=config.request_delay + self.rate_limiter = AdaptiveRateLimiter( + base_delay=config.request_delay, + max_delay=60.0 ) self.logger = logging.getLogger(__name__) @@ -251,7 +251,16 @@ class DiscordDataClient(discord.Client): await self.rate_limiter.wait() # Use fetch_user_profile to get mutual guilds - profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True) + try: + profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True) + self.rate_limiter.on_success() + except discord.HTTPException as e: + if e.status == 429: # Rate limited + retry_after = getattr(e, 'retry_after', None) + self.rate_limiter.on_rate_limit(retry_after) + raise + else: + raise if hasattr(profile, 'mutual_guilds') and profile.mutual_guilds: mutual_guild_ids = [] diff --git a/src/database.py b/src/database.py index 745887a..e8b9188 100644 --- a/src/database.py +++ b/src/database.py @@ -15,7 +15,7 @@ import logging # Optional MariaDB support try: - from asyncmy import connect + from asyncmy import connect, create_pool from asyncmy.cursors import DictCursor from asyncmy.errors import MySQLError MARIADB_AVAILABLE = True @@ -255,6 +255,44 @@ class JSONDatabase: result[server_id] = servers[str(server_id)] return result + async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]: + """Get users with pagination (JSON implementation).""" + async with self._lock: + data = self._load_data() + users = [] + for user_id, user_data in data.items(): + if user_id != "servers": # Skip servers data + users.append(UserData.from_dict(user_data)) + + # Sort by user_id and paginate + users.sort(key=lambda u: u.user_id) + return users[offset:offset + limit] + + async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]: + """Search users with pagination (JSON implementation).""" + async with self._lock: + data = self._load_data() + matching_users = [] + + for user_id, user_data in data.items(): + if user_id != "servers": # Skip servers data + user = UserData.from_dict(user_data) + if (query.lower() in user.username.lower() or + query.lower() in (user.display_name or "").lower() or + query == str(user.user_id)): + matching_users.append(user) + + # Sort by user_id and paginate + matching_users.sort(key=lambda u: u.user_id) + return matching_users[offset:offset + limit] + + async def get_user_count_total(self) -> int: + """Get total number of users (JSON implementation).""" + async with self._lock: + data = self._load_data() + # Count all keys except "servers" + return len([k for k in data.keys() if k != "servers"]) + async def close(self): """Close database (no-op for JSON database).""" self.logger.info("JSON database closed") @@ -284,101 +322,132 @@ class MariaDBDatabase: } 
self.logger = logging.getLogger(__name__) self.pool = None - self._lock = asyncio.Lock() # Table schema versions self.schema_version = 1 async def initialize(self): - """Initialize database connection and ensure tables exist.""" + """Initialize database connection pool and ensure tables exist.""" try: - # Add DictCursor to config for dictionary results - self.db_config['cursor_cls'] = DictCursor - self.pool = await connect(**self.db_config) + # Create connection pool instead of single connection + self.pool = await create_pool( + host=self.db_config['host'], + port=self.db_config['port'], + user=self.db_config['user'], + password=self.db_config['password'], + db=self.db_config['db'], + minsize=5, + maxsize=20, + charset='utf8mb4', + cursorclass=DictCursor + ) await self._create_tables() - self.logger.info("Database connection established") + await self._create_indexes() + self.logger.info("Database connection pool established") except MySQLError as e: self.logger.error(f"Database connection failed: {e}") raise async def _create_tables(self): """Create necessary tables if they don't exist.""" - async with self.pool.cursor() as cursor: - await cursor.execute(""" - CREATE TABLE IF NOT EXISTS users ( - user_id BIGINT PRIMARY KEY, - username VARCHAR(32) NOT NULL, - discriminator VARCHAR(4) NOT NULL, - display_name VARCHAR(32), - avatar_url VARCHAR(255), - banner_url VARCHAR(255), - bio TEXT, - status VARCHAR(20), - activity VARCHAR(50), - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) - """) + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + user_id BIGINT PRIMARY KEY, + username VARCHAR(32) NOT NULL, + discriminator VARCHAR(4) NOT NULL, + display_name VARCHAR(32), + avatar_url VARCHAR(255), + banner_url VARCHAR(255), + bio TEXT, + status VARCHAR(20), + activity VARCHAR(50), + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) + """) + + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS servers ( + server_id BIGINT PRIMARY KEY, + server_name VARCHAR(100) NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) + """) + + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS user_servers ( + user_id BIGINT, + server_id BIGINT, + PRIMARY KEY (user_id, server_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE, + FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE + ) + """) + + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_version ( + version INT PRIMARY KEY, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) - await cursor.execute(""" - CREATE TABLE IF NOT EXISTS servers ( - server_id BIGINT PRIMARY KEY, - server_name VARCHAR(100) NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - ) - """) - - await cursor.execute(""" - CREATE TABLE IF NOT EXISTS user_servers ( - user_id BIGINT, - server_id BIGINT, - PRIMARY KEY (user_id, server_id), - FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE, - FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE - ) - """) - - await cursor.execute(""" - CREATE TABLE IF NOT EXISTS schema_version ( - version INT PRIMARY KEY, - updated_at DATETIME DEFAULT 
CURRENT_TIMESTAMP - ) - """) - - # Check schema version - await cursor.execute(""" - INSERT IGNORE INTO schema_version (version) VALUES (%s) - """, (self.schema_version,)) + # Check schema version + await cursor.execute(""" + INSERT IGNORE INTO schema_version (version) VALUES (%s) + """, (self.schema_version,)) + + async def _create_indexes(self): + """Create performance indexes.""" + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + # Create indexes for better performance + indexes = [ + "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)", + "CREATE INDEX IF NOT EXISTS idx_users_display_name ON users(display_name)", + "CREATE INDEX IF NOT EXISTS idx_users_updated_at ON users(updated_at)", + "CREATE INDEX IF NOT EXISTS idx_user_servers_server_id ON user_servers(server_id)", + "CREATE INDEX IF NOT EXISTS idx_users_username_display ON users(username, display_name)", + "CREATE INDEX IF NOT EXISTS idx_servers_name ON servers(server_name)", + ] + + for index_sql in indexes: + try: + await cursor.execute(index_sql) + self.logger.debug(f"Created index: {index_sql}") + except MySQLError as e: + if "Duplicate key name" not in str(e): + self.logger.warning(f"Failed to create index: {e}") + + # Add full-text search index + try: + await cursor.execute(""" + ALTER TABLE users ADD FULLTEXT(username, display_name, bio) + """) + except MySQLError as e: + if "Duplicate key name" not in str(e): + self.logger.debug(f"Full-text index already exists or failed: {e}") self.logger.info("Database tables verified/created") - async def _execute_query(self, query: str, params: tuple = None): - """Execute a database query with error handling.""" - async with self._lock: - try: - async with self.pool.cursor() as cursor: - await cursor.execute(query, params) - return cursor - except MySQLError as e: - self.logger.error(f"Database error: {e}") - raise - async def get_user(self, user_id: int) -> Optional[UserData]: """Get user data by ID with associated servers.""" - async with self.pool.cursor() as cursor: - await cursor.execute(""" - SELECT u.*, GROUP_CONCAT(us.server_id) AS servers - FROM users u - LEFT JOIN user_servers us ON u.user_id = us.user_id - WHERE u.user_id = %s - GROUP BY u.user_id - """, (user_id,)) - - result = await cursor.fetchone() - if result: - return self._parse_user_result(result) - return None + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + WHERE u.user_id = %s + GROUP BY u.user_id + """, (user_id,)) + + result = await cursor.fetchone() + if result: + return self._parse_user_result(result) + return None def _parse_user_result(self, result: Dict) -> UserData: """Convert database result to UserData object.""" @@ -400,64 +469,65 @@ class MariaDBDatabase: async def save_user(self, user_data: UserData): """Save or update user data with transaction.""" - async with self.pool.cursor() as cursor: - try: - # Start transaction - await cursor.execute("START TRANSACTION") - - # Upsert user data - await cursor.execute(""" - INSERT INTO users ( - user_id, username, discriminator, display_name, - avatar_url, banner_url, bio, status, activity - ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) - ON DUPLICATE KEY UPDATE - username = VALUES(username), - discriminator = VALUES(discriminator), - display_name = VALUES(display_name), - avatar_url = VALUES(avatar_url), - banner_url = 
VALUES(banner_url), - bio = VALUES(bio), - status = VALUES(status), - activity = VALUES(activity) - """, ( - user_data.user_id, - user_data.username, - user_data.discriminator, - user_data.display_name, - user_data.avatar_url, - user_data.banner_url, - user_data.bio, - user_data.status, - user_data.activity - )) + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + try: + # Start transaction + await cursor.execute("START TRANSACTION") + + # Upsert user data + await cursor.execute(""" + INSERT INTO users ( + user_id, username, discriminator, display_name, + avatar_url, banner_url, bio, status, activity + ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + username = VALUES(username), + discriminator = VALUES(discriminator), + display_name = VALUES(display_name), + avatar_url = VALUES(avatar_url), + banner_url = VALUES(banner_url), + bio = VALUES(bio), + status = VALUES(status), + activity = VALUES(activity) + """, ( + user_data.user_id, + user_data.username, + user_data.discriminator, + user_data.display_name, + user_data.avatar_url, + user_data.banner_url, + user_data.bio, + user_data.status, + user_data.activity + )) - # Update servers relationship - await cursor.execute( - "DELETE FROM user_servers WHERE user_id = %s", - (user_data.user_id,) - ) - - if user_data.servers: - for server_id in user_data.servers: - await cursor.execute( - "INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)", - (user_data.user_id, server_id) - ) - - # Commit transaction - await cursor.execute("COMMIT") - self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}") + # Update servers relationship + await cursor.execute( + "DELETE FROM user_servers WHERE user_id = %s", + (user_data.user_id,) + ) + + if user_data.servers: + for server_id in user_data.servers: + await cursor.execute( + "INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)", + (user_data.user_id, server_id) + ) + + # Commit transaction + await cursor.execute("COMMIT") + self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}") - except MySQLError as e: - self.logger.error(f"Error saving user: {e}") - await cursor.execute("ROLLBACK") - raise + except MySQLError as e: + self.logger.error(f"Error saving user: {e}") + await cursor.execute("ROLLBACK") + raise async def save_server(self, server_id: int, server_name: str): """Save server information.""" - async with self._lock: - async with self.pool.cursor() as cursor: + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: await cursor.execute(""" INSERT INTO servers (server_id, server_name) VALUES (%s, %s) @@ -471,8 +541,8 @@ class MariaDBDatabase: if not server_ids: return {} - async with self._lock: - async with self.pool.cursor() as cursor: + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: placeholders = ','.join(['%s'] * len(server_ids)) await cursor.execute(f""" SELECT server_id, server_name @@ -485,8 +555,8 @@ class MariaDBDatabase: async def add_server_to_user(self, user_id: int, server_id: int): """Add a server to user's server list.""" - async with self._lock: - async with self.pool.cursor() as cursor: + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: await cursor.execute(""" INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s) @@ -494,42 +564,175 @@ class MariaDBDatabase: async def get_all_users(self) -> List[UserData]: """Get all users from the database.""" - async with 
self.pool.cursor() as cursor: - await cursor.execute(""" - SELECT u.*, GROUP_CONCAT(us.server_id) AS servers - FROM users u - LEFT JOIN user_servers us ON u.user_id = us.user_id - GROUP BY u.user_id - """) - results = await cursor.fetchall() - return [self._parse_user_result(r) for r in results] + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + GROUP BY u.user_id + """) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] + + async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]: + """Get users with pagination.""" + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + GROUP BY u.user_id + ORDER BY u.user_id + LIMIT %s OFFSET %s + """, (limit, offset)) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] + + async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]: + """Search users with pagination.""" + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + # Try full-text search first + try: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + WHERE MATCH(u.username, u.display_name, u.bio) AGAINST(%s IN BOOLEAN MODE) + GROUP BY u.user_id + ORDER BY u.user_id + LIMIT %s OFFSET %s + """, (f"*{query}*", limit, offset)) + results = await cursor.fetchall() + if results: + return [self._parse_user_result(r) for r in results] + except MySQLError: + # Fall back to LIKE search if full-text fails + pass + + # Fallback to LIKE search + search_pattern = f"%{query}%" + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + WHERE u.username LIKE %s OR u.display_name LIKE %s OR u.user_id = %s + GROUP BY u.user_id + ORDER BY u.user_id + LIMIT %s OFFSET %s + """, (search_pattern, search_pattern, query if query.isdigit() else None, limit, offset)) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] + + async def get_user_count_total(self) -> int: + """Get total number of users.""" + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute("SELECT COUNT(*) as count FROM users") + result = await cursor.fetchone() + return result['count'] if result else 0 + + async def save_users_batch(self, users: List[UserData]): + """Save multiple users in a batch operation.""" + if not users: + return + + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + try: + await cursor.execute("START TRANSACTION") + + # Prepare batch data for users + user_data = [] + server_data = [] + + for user in users: + user_data.append(( + user.user_id, + user.username, + user.discriminator, + user.display_name, + user.avatar_url, + user.banner_url, + user.bio, + user.status, + user.activity + )) + + # Collect server relationships + for server_id in user.servers: + server_data.append((user.user_id, server_id)) + + # Batch insert users + await cursor.executemany(""" + INSERT INTO users (user_id, username, discriminator, 
display_name, + avatar_url, banner_url, bio, status, activity) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + username = VALUES(username), + discriminator = VALUES(discriminator), + display_name = VALUES(display_name), + avatar_url = VALUES(avatar_url), + banner_url = VALUES(banner_url), + bio = VALUES(bio), + status = VALUES(status), + activity = VALUES(activity) + """, user_data) + + # Clear existing server relationships for these users + user_ids = [user.user_id for user in users] + if user_ids: + placeholders = ','.join(['%s'] * len(user_ids)) + await cursor.execute(f""" + DELETE FROM user_servers WHERE user_id IN ({placeholders}) + """, user_ids) + + # Batch insert server relationships + if server_data: + await cursor.executemany(""" + INSERT IGNORE INTO user_servers (user_id, server_id) + VALUES (%s, %s) + """, server_data) + + await cursor.execute("COMMIT") + self.logger.info(f"Batch saved {len(users)} users") + + except MySQLError as e: + await cursor.execute("ROLLBACK") + self.logger.error(f"Error in batch save: {e}") + raise async def get_users_by_server(self, server_id: int) -> List[UserData]: """Get all users that are members of a specific server.""" - async with self.pool.cursor() as cursor: - await cursor.execute(""" - SELECT u.*, GROUP_CONCAT(us.server_id) AS servers - FROM users u - JOIN user_servers us ON u.user_id = us.user_id - WHERE us.server_id = %s - GROUP BY u.user_id - """, (server_id,)) - results = await cursor.fetchall() - return [self._parse_user_result(r) for r in results] + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + JOIN user_servers us ON u.user_id = us.user_id + WHERE us.server_id = %s + GROUP BY u.user_id + """, (server_id,)) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] async def get_user_count(self) -> int: """Get total number of users in database.""" - async with self.pool.cursor() as cursor: - await cursor.execute("SELECT COUNT(*) as user_count FROM users") - result = await cursor.fetchone() - return result['user_count'] if result else 0 + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute("SELECT COUNT(*) as user_count FROM users") + result = await cursor.fetchone() + return result['user_count'] if result else 0 async def get_server_count(self) -> int: """Get total number of unique servers.""" - async with self.pool.cursor() as cursor: - await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers") - result = await cursor.fetchone() - return result['server_count'] if result else 0 + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers") + result = await cursor.fetchone() + return result['server_count'] if result else 0 async def get_statistics(self) -> Dict[str, Any]: """Get database statistics using optimized queries.""" @@ -538,30 +741,48 @@ class MariaDBDatabase: 'total_servers': await self.get_server_count(), } - async with self.pool.cursor() as cursor: - await cursor.execute(""" - SELECT server_id, COUNT(user_id) as user_count - FROM user_servers - GROUP BY server_id - ORDER BY user_count DESC - LIMIT 10 - """) - most_active = await cursor.fetchall() - # Convert to list of tuples for consistency with JSON version - stats['most_active_servers'] = 
[(row['server_id'], row['user_count']) for row in most_active] - - # Get database size - await cursor.execute(""" - SELECT - ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb - FROM information_schema.tables - WHERE table_schema = DATABASE() - """) - size_result = await cursor.fetchone() - stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024) # Convert to bytes + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + SELECT server_id, COUNT(user_id) as user_count + FROM user_servers + GROUP BY server_id + ORDER BY user_count DESC + LIMIT 10 + """) + most_active = await cursor.fetchall() + # Convert to list of tuples for consistency with JSON version + stats['most_active_servers'] = [(row['server_id'], row['user_count']) for row in most_active] + + # Get database size + await cursor.execute(""" + SELECT + ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb + FROM information_schema.tables + WHERE table_schema = DATABASE() + """) + size_result = await cursor.fetchone() + stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024) # Convert to bytes return stats + async def export_to_csv(self, output_path: str): + """Export data to CSV format.""" + import csv + users = await self.get_all_users() + + with open(output_path, 'w', newline='', encoding='utf-8') as csvfile: + fieldnames = ['user_id', 'username', 'discriminator', 'display_name', + 'avatar_url', 'banner_url', 'bio', 'status', 'activity', + 'servers', 'created_at', 'updated_at'] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for user in users: + row = user.to_dict() + row['servers'] = ','.join(map(str, user.servers)) + writer.writerow(row) + async def cleanup_old_backups(self, max_backups: int = 5): """Clean up old backup documents (placeholder for MariaDB).""" # For MariaDB, we could implement this as cleaning up old backup tables diff --git a/src/rate_limiter.py b/src/rate_limiter.py index 44629c3..90e87fe 100644 --- a/src/rate_limiter.py +++ b/src/rate_limiter.py @@ -9,6 +9,72 @@ from typing import Optional import logging +class AdaptiveRateLimiter: + """Adaptive rate limiter that adjusts to Discord API rate limits.""" + + def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0): + self.base_delay = base_delay + self.max_delay = max_delay + self.current_delay = base_delay + self.consecutive_429s = 0 + self.last_request_time: Optional[float] = None + self.success_count = 0 + self.logger = logging.getLogger(__name__) + + async def wait(self): + """Wait with adaptive delay based on rate limit responses.""" + current_time = time.time() + + # Check if we need to wait based on current delay + if self.last_request_time is not None: + time_since_last_request = current_time - self.last_request_time + if time_since_last_request < self.current_delay: + wait_time = self.current_delay - time_since_last_request + self.logger.debug(f"Adaptive rate limit wait: {wait_time:.2f}s") + await asyncio.sleep(wait_time) + + self.last_request_time = time.time() + + def on_success(self): + """Called when a request succeeds.""" + self.success_count += 1 + + # After 10 successful requests, try to reduce delay + if self.success_count >= 10: + old_delay = self.current_delay + self.current_delay = max(self.base_delay, self.current_delay * 0.9) + self.consecutive_429s = 0 + self.success_count = 0 + if old_delay != self.current_delay: + self.logger.debug(f"Reduced 
delay to {self.current_delay:.2f}s") + + def on_rate_limit(self, retry_after: Optional[float] = None): + """Called when a 429 rate limit is encountered.""" + self.consecutive_429s += 1 + self.success_count = 0 + + if retry_after: + # Use Discord's suggested retry time + self.current_delay = min(self.max_delay, retry_after) + self.logger.warning(f"Rate limited, using Discord's retry_after: {retry_after}s") + else: + # Exponential backoff + old_delay = self.current_delay + self.current_delay = min(self.max_delay, self.base_delay * (2 ** self.consecutive_429s)) + self.logger.warning(f"Rate limited, increased delay from {old_delay:.2f}s to {self.current_delay:.2f}s") + + def get_stats(self) -> dict: + """Get rate limiter statistics.""" + current_time = time.time() + return { + 'current_delay': self.current_delay, + 'base_delay': self.base_delay, + 'consecutive_429s': self.consecutive_429s, + 'success_count': self.success_count, + 'time_since_last_request': current_time - self.last_request_time if self.last_request_time else 0 + } + + class RateLimiter: """Rate limiter to prevent hitting Discord API limits."""