diff --git a/src/database.py b/src/database.py index 24ac105..a86d1c6 100644 --- a/src/database.py +++ b/src/database.py @@ -1,15 +1,14 @@ """ -JSON database manager for Discord user data storage. +MariaDB database manager for Discord user data storage. """ -import json import asyncio -import shutil from datetime import datetime -from pathlib import Path from typing import Dict, List, Optional, Any -from dataclasses import dataclass, asdict +from dataclasses import dataclass import logging +from asyncmy import connect, Connection, Cursor +from asyncmy.errors import MySQLError @dataclass @@ -25,181 +24,257 @@ class UserData: status: Optional[str] = None activity: Optional[str] = None servers: List[int] = None - created_at: str = None - updated_at: str = None - + created_at: datetime = None + updated_at: datetime = None + def __post_init__(self): if self.servers is None: self.servers = [] - - current_time = datetime.utcnow().isoformat() + current_time = datetime.utcnow() if self.created_at is None: self.created_at = current_time self.updated_at = current_time -class JSONDatabase: - """JSON-based database for storing Discord user data.""" +class MariaDBDatabase: + """MariaDB-based database for storing Discord user data.""" - def __init__(self, database_path: str): - """Initialize the JSON database.""" - self.database_path = Path(database_path) - self.backup_path = Path("data/backups") + def __init__(self, + host: str, + user: str, + password: str, + database: str, + port: int = 3306): + """Initialize the MariaDB connection.""" + self.db_config = { + 'host': host, + 'port': port, + 'user': user, + 'password': password, + 'db': database, + } self.logger = logging.getLogger(__name__) + self.pool = None self._lock = asyncio.Lock() - self._data: Dict[str, Dict] = {} - # Ensure database directory exists - self.database_path.parent.mkdir(parents=True, exist_ok=True) - self.backup_path.mkdir(parents=True, exist_ok=True) - - # Load existing data - self._load_data() - - def _load_data(self): - """Load data from JSON file.""" - if self.database_path.exists(): - try: - with open(self.database_path, 'r', encoding='utf-8') as f: - self._data = json.load(f) - self.logger.info(f"Loaded {len(self._data)} users from database") - except Exception as e: - self.logger.error(f"Error loading database: {e}") - self._data = {} - else: - self._data = {} - self.logger.info("Created new database") - - async def _save_data(self): - """Save data to JSON file.""" + # Table schema versions + self.schema_version = 1 + + async def initialize(self): + """Initialize database connection and ensure tables exist.""" + try: + self.pool = await connect(**self.db_config) + await self._create_tables() + self.logger.info("Database connection established") + except MySQLError as e: + self.logger.error(f"Database connection failed: {e}") + raise + + async def _create_tables(self): + """Create necessary tables if they don't exist.""" + async with self.pool.cursor() as cursor: + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + user_id BIGINT PRIMARY KEY, + username VARCHAR(32) NOT NULL, + discriminator VARCHAR(4) NOT NULL, + display_name VARCHAR(32), + avatar_url VARCHAR(255), + banner_url VARCHAR(255), + bio TEXT, + status VARCHAR(20), + activity VARCHAR(50), + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + ) + """) + + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS user_servers ( + user_id BIGINT, + server_id BIGINT, + PRIMARY KEY (user_id, server_id), + FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE + ) + """) + + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_version ( + version INT PRIMARY KEY, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Check schema version + await cursor.execute(""" + INSERT IGNORE INTO schema_version (version) VALUES (%s) + """, (self.schema_version,)) + + self.logger.info("Database tables verified/created") + + async def _execute_query(self, query: str, params: tuple = None): + """Execute a database query with error handling.""" async with self._lock: try: - # Create backup before saving - if self.database_path.exists(): - backup_filename = f"users_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - backup_path = self.backup_path / backup_filename - shutil.copy2(self.database_path, backup_path) - - # Save data - with open(self.database_path, 'w', encoding='utf-8') as f: - json.dump(self._data, f, indent=2, ensure_ascii=False) - - self.logger.debug(f"Saved {len(self._data)} users to database") - - except Exception as e: - self.logger.error(f"Error saving database: {e}") - + async with self.pool.cursor() as cursor: + await cursor.execute(query, params) + return cursor + except MySQLError as e: + self.logger.error(f"Database error: {e}") + raise + async def get_user(self, user_id: int) -> Optional[UserData]: - """Get user data by ID.""" - user_key = str(user_id) - if user_key in self._data: - user_dict = self._data[user_key] - return UserData(**user_dict) - return None - + """Get user data by ID with associated servers.""" + async with self.pool.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + WHERE u.user_id = %s + GROUP BY u.user_id + """, (user_id,)) + + result = await cursor.fetchone() + if result: + return self._parse_user_result(result) + return None + + def _parse_user_result(self, result: Dict) -> UserData: + """Convert database result to UserData object.""" + servers = list(map(int, result['servers'].split(','))) if result['servers'] else [] + return UserData( + user_id=result['user_id'], + username=result['username'], + discriminator=result['discriminator'], + display_name=result['display_name'], + avatar_url=result['avatar_url'], + banner_url=result['banner_url'], + bio=result['bio'], + status=result['status'], + activity=result['activity'], + servers=servers, + created_at=result['created_at'], + updated_at=result['updated_at'] + ) + async def save_user(self, user_data: UserData): - """Save or update user data.""" - user_key = str(user_data.user_id) - - # If user exists, preserve created_at timestamp - if user_key in self._data: - user_data.created_at = self._data[user_key]['created_at'] - - # Update timestamp - user_data.updated_at = datetime.utcnow().isoformat() - - # Save to memory - self._data[user_key] = asdict(user_data) - - # Save to disk - await self._save_data() - - self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator} ({user_data.user_id})") - + """Save or update user data with transaction.""" + async with self.pool.begin() as conn: + try: + # Upsert user data + await conn.execute(""" + INSERT INTO users ( + user_id, username, discriminator, display_name, + avatar_url, banner_url, bio, status, activity + ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + username = VALUES(username), + discriminator = VALUES(discriminator), + display_name = VALUES(display_name), + avatar_url = VALUES(avatar_url), + banner_url = VALUES(banner_url), + bio = VALUES(bio), + status = VALUES(status), + activity = VALUES(activity) + """, ( + user_data.user_id, + user_data.username, + user_data.discriminator, + user_data.display_name, + user_data.avatar_url, + user_data.banner_url, + user_data.bio, + user_data.status, + user_data.activity + )) + + # Update servers relationship + await conn.execute( + "DELETE FROM user_servers WHERE user_id = %s", + (user_data.user_id,) + ) + + if user_data.servers: + server_values = [(user_data.user_id, s) for s in user_data.servers] + await conn.executemany( + "INSERT INTO user_servers (user_id, server_id) VALUES (%s, %s)", + server_values + ) + + self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}") + + except MySQLError as e: + self.logger.error(f"Error saving user: {e}") + await conn.rollback() + raise + async def add_server_to_user(self, user_id: int, server_id: int): """Add a server to user's server list.""" - user_key = str(user_id) - if user_key in self._data: - if server_id not in self._data[user_key]['servers']: - self._data[user_key]['servers'].append(server_id) - self._data[user_key]['updated_at'] = datetime.utcnow().isoformat() - await self._save_data() - + await self._execute_query(""" + INSERT IGNORE INTO user_servers (user_id, server_id) + VALUES (%s, %s) + """, (user_id, server_id)) + async def get_all_users(self) -> List[UserData]: """Get all users from the database.""" - return [UserData(**user_dict) for user_dict in self._data.values()] - + async with self.pool.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + LEFT JOIN user_servers us ON u.user_id = us.user_id + GROUP BY u.user_id + """) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] + async def get_users_by_server(self, server_id: int) -> List[UserData]: """Get all users that are members of a specific server.""" - users = [] - for user_dict in self._data.values(): - if server_id in user_dict.get('servers', []): - users.append(UserData(**user_dict)) - return users - + async with self.pool.cursor() as cursor: + await cursor.execute(""" + SELECT u.*, GROUP_CONCAT(us.server_id) AS servers + FROM users u + JOIN user_servers us ON u.user_id = us.user_id + WHERE us.server_id = %s + GROUP BY u.user_id + """, (server_id,)) + results = await cursor.fetchall() + return [self._parse_user_result(r) for r in results] + async def get_user_count(self) -> int: """Get total number of users in database.""" - return len(self._data) - + async with self.pool.cursor() as cursor: + await cursor.execute("SELECT COUNT(*) FROM users") + result = await cursor.fetchone() + return result['COUNT(*)'] + async def get_server_count(self) -> int: """Get total number of unique servers.""" - servers = set() - for user_dict in self._data.values(): - servers.update(user_dict.get('servers', [])) - return len(servers) - - async def cleanup_old_backups(self, max_backups: int = 10): - """Clean up old backup files, keeping only the most recent ones.""" - backup_files = sorted(self.backup_path.glob("users_backup_*.json")) - - if len(backup_files) > max_backups: - files_to_remove = backup_files[:-max_backups] - for file_path in files_to_remove: - try: - file_path.unlink() - self.logger.info(f"Removed old backup: {file_path.name}") - except Exception as e: - self.logger.error(f"Error removing backup {file_path.name}: {e}") - - async def export_to_csv(self, output_path: str): - """Export user data to CSV format.""" - import csv - - output_path = Path(output_path) - - try: - with open(output_path, 'w', newline='', encoding='utf-8') as csvfile: - fieldnames = ['user_id', 'username', 'discriminator', 'display_name', - 'avatar_url', 'bio', 'status', 'servers', 'created_at', 'updated_at'] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - - writer.writeheader() - for user_dict in self._data.values(): - # Convert servers list to string - user_dict_copy = user_dict.copy() - user_dict_copy['servers'] = ','.join(map(str, user_dict.get('servers', []))) - writer.writerow(user_dict_copy) - - self.logger.info(f"Exported {len(self._data)} users to {output_path}") - - except Exception as e: - self.logger.error(f"Error exporting to CSV: {e}") - + async with self.pool.cursor() as cursor: + await cursor.execute("SELECT COUNT(DISTINCT server_id) FROM user_servers") + result = await cursor.fetchone() + return result['COUNT(DISTINCT server_id)'] + async def get_statistics(self) -> Dict[str, Any]: - """Get database statistics.""" + """Get database statistics using optimized queries.""" stats = { 'total_users': await self.get_user_count(), 'total_servers': await self.get_server_count(), - 'database_size': self.database_path.stat().st_size if self.database_path.exists() else 0 } - - # Most active servers - server_counts = {} - for user_dict in self._data.values(): - for server_id in user_dict.get('servers', []): - server_counts[server_id] = server_counts.get(server_id, 0) + 1 - - stats['most_active_servers'] = sorted(server_counts.items(), - key=lambda x: x[1], reverse=True)[:10] - - return stats \ No newline at end of file + + async with self.pool.cursor() as cursor: + await cursor.execute(""" + SELECT server_id, COUNT(user_id) as user_count + FROM user_servers + GROUP BY server_id + ORDER BY user_count DESC + LIMIT 10 + """) + stats['most_active_servers'] = await cursor.fetchall() + + return stats + + async def close(self): + """Close database connection.""" + if self.pool: + await self.pool.close() + self.logger.info("Database connection closed")