i have no sweet clue if this works

glitchy 2025-07-13 21:30:12 +02:00
parent a9bcce85d6
commit 5dbc12e943


@@ -1,15 +1,14 @@
 """
-JSON database manager for Discord user data storage.
+MariaDB database manager for Discord user data storage.
 """
 
-import json
 import asyncio
-import shutil
 from datetime import datetime
-from pathlib import Path
 from typing import Dict, List, Optional, Any
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass
 import logging
+from asyncmy import connect, Connection, Cursor
+from asyncmy.errors import MySQLError
 
 
 @dataclass
@@ -25,181 +24,257 @@ class UserData:
     status: Optional[str] = None
     activity: Optional[str] = None
     servers: List[int] = None
-    created_at: str = None
-    updated_at: str = None
+    created_at: datetime = None
+    updated_at: datetime = None
 
     def __post_init__(self):
         if self.servers is None:
             self.servers = []
-        current_time = datetime.utcnow().isoformat()
+        current_time = datetime.utcnow()
         if self.created_at is None:
             self.created_at = current_time
         self.updated_at = current_time
 
-class JSONDatabase:
-    """JSON-based database for storing Discord user data."""
+class MariaDBDatabase:
+    """MariaDB-based database for storing Discord user data."""
 
-    def __init__(self, database_path: str):
-        """Initialize the JSON database."""
-        self.database_path = Path(database_path)
-        self.backup_path = Path("data/backups")
+    def __init__(self,
+                 host: str,
+                 user: str,
+                 password: str,
+                 database: str,
+                 port: int = 3306):
+        """Initialize the MariaDB connection."""
+        self.db_config = {
+            'host': host,
+            'port': port,
+            'user': user,
+            'password': password,
+            'db': database,
+        }
         self.logger = logging.getLogger(__name__)
+        self.pool = None
         self._lock = asyncio.Lock()
-        self._data: Dict[str, Dict] = {}
-
-        # Ensure database directory exists
-        self.database_path.parent.mkdir(parents=True, exist_ok=True)
-        self.backup_path.mkdir(parents=True, exist_ok=True)
-
-        # Load existing data
-        self._load_data()
-
-    def _load_data(self):
-        """Load data from JSON file."""
-        if self.database_path.exists():
-            try:
-                with open(self.database_path, 'r', encoding='utf-8') as f:
-                    self._data = json.load(f)
-                self.logger.info(f"Loaded {len(self._data)} users from database")
-            except Exception as e:
-                self.logger.error(f"Error loading database: {e}")
-                self._data = {}
-        else:
-            self._data = {}
-            self.logger.info("Created new database")
+
+        # Table schema versions
+        self.schema_version = 1
+
+    async def initialize(self):
+        """Initialize database connection and ensure tables exist."""
+        try:
+            self.pool = await connect(**self.db_config)
+            await self._create_tables()
+            self.logger.info("Database connection established")
+        except MySQLError as e:
+            self.logger.error(f"Database connection failed: {e}")
+            raise
 
-    async def _save_data(self):
-        """Save data to JSON file."""
-        async with self._lock:
-            try:
-                # Create backup before saving
-                if self.database_path.exists():
-                    backup_filename = f"users_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
-                    backup_path = self.backup_path / backup_filename
-                    shutil.copy2(self.database_path, backup_path)
-
-                # Save data
-                with open(self.database_path, 'w', encoding='utf-8') as f:
-                    json.dump(self._data, f, indent=2, ensure_ascii=False)
-
-                self.logger.debug(f"Saved {len(self._data)} users to database")
-            except Exception as e:
-                self.logger.error(f"Error saving database: {e}")
+    async def _create_tables(self):
+        """Create necessary tables if they don't exist."""
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("""
+                CREATE TABLE IF NOT EXISTS users (
+                    user_id BIGINT PRIMARY KEY,
+                    username VARCHAR(32) NOT NULL,
+                    discriminator VARCHAR(4) NOT NULL,
+                    display_name VARCHAR(32),
+                    avatar_url VARCHAR(255),
+                    banner_url VARCHAR(255),
+                    bio TEXT,
+                    status VARCHAR(20),
+                    activity VARCHAR(50),
+                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+                )
+            """)
+            await cursor.execute("""
+                CREATE TABLE IF NOT EXISTS user_servers (
+                    user_id BIGINT,
+                    server_id BIGINT,
+                    PRIMARY KEY (user_id, server_id),
+                    FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
+                )
+            """)
+            await cursor.execute("""
+                CREATE TABLE IF NOT EXISTS schema_version (
+                    version INT PRIMARY KEY,
+                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+            # Check schema version
+            await cursor.execute("""
+                INSERT IGNORE INTO schema_version (version) VALUES (%s)
+            """, (self.schema_version,))
+            self.logger.info("Database tables verified/created")
+
+    async def _execute_query(self, query: str, params: tuple = None):
+        """Execute a database query with error handling."""
+        async with self._lock:
+            try:
+                async with self.pool.cursor() as cursor:
+                    await cursor.execute(query, params)
+                    return cursor
+            except MySQLError as e:
+                self.logger.error(f"Database error: {e}")
+                raise
 
     async def get_user(self, user_id: int) -> Optional[UserData]:
-        """Get user data by ID."""
-        user_key = str(user_id)
-        if user_key in self._data:
-            user_dict = self._data[user_key]
-            return UserData(**user_dict)
-        return None
+        """Get user data by ID with associated servers."""
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("""
+                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
+                FROM users u
+                LEFT JOIN user_servers us ON u.user_id = us.user_id
+                WHERE u.user_id = %s
+                GROUP BY u.user_id
+            """, (user_id,))
+            result = await cursor.fetchone()
+            if result:
+                return self._parse_user_result(result)
+            return None
+
+    def _parse_user_result(self, result: Dict) -> UserData:
+        """Convert database result to UserData object."""
+        servers = list(map(int, result['servers'].split(','))) if result['servers'] else []
+        return UserData(
+            user_id=result['user_id'],
+            username=result['username'],
+            discriminator=result['discriminator'],
+            display_name=result['display_name'],
+            avatar_url=result['avatar_url'],
+            banner_url=result['banner_url'],
+            bio=result['bio'],
+            status=result['status'],
+            activity=result['activity'],
+            servers=servers,
+            created_at=result['created_at'],
+            updated_at=result['updated_at']
+        )
 
     async def save_user(self, user_data: UserData):
-        """Save or update user data."""
-        user_key = str(user_data.user_id)
-
-        # If user exists, preserve created_at timestamp
-        if user_key in self._data:
-            user_data.created_at = self._data[user_key]['created_at']
-
-        # Update timestamp
-        user_data.updated_at = datetime.utcnow().isoformat()
-
-        # Save to memory
-        self._data[user_key] = asdict(user_data)
-
-        # Save to disk
-        await self._save_data()
-
-        self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator} ({user_data.user_id})")
+        """Save or update user data with transaction."""
+        async with self.pool.begin() as conn:
+            try:
+                # Upsert user data
+                await conn.execute("""
+                    INSERT INTO users (
+                        user_id, username, discriminator, display_name,
+                        avatar_url, banner_url, bio, status, activity
+                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
+                    ON DUPLICATE KEY UPDATE
+                        username = VALUES(username),
+                        discriminator = VALUES(discriminator),
+                        display_name = VALUES(display_name),
+                        avatar_url = VALUES(avatar_url),
+                        banner_url = VALUES(banner_url),
+                        bio = VALUES(bio),
+                        status = VALUES(status),
+                        activity = VALUES(activity)
+                """, (
+                    user_data.user_id,
+                    user_data.username,
+                    user_data.discriminator,
+                    user_data.display_name,
+                    user_data.avatar_url,
+                    user_data.banner_url,
+                    user_data.bio,
+                    user_data.status,
+                    user_data.activity
+                ))
+
+                # Update servers relationship
+                await conn.execute(
+                    "DELETE FROM user_servers WHERE user_id = %s",
+                    (user_data.user_id,)
+                )
+
+                if user_data.servers:
+                    server_values = [(user_data.user_id, s) for s in user_data.servers]
+                    await conn.executemany(
+                        "INSERT INTO user_servers (user_id, server_id) VALUES (%s, %s)",
+                        server_values
+                    )
+
+                self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
+
+            except MySQLError as e:
+                self.logger.error(f"Error saving user: {e}")
+                await conn.rollback()
+                raise
 
     async def add_server_to_user(self, user_id: int, server_id: int):
         """Add a server to user's server list."""
-        user_key = str(user_id)
-        if user_key in self._data:
-            if server_id not in self._data[user_key]['servers']:
-                self._data[user_key]['servers'].append(server_id)
-                self._data[user_key]['updated_at'] = datetime.utcnow().isoformat()
-                await self._save_data()
+        await self._execute_query("""
+            INSERT IGNORE INTO user_servers (user_id, server_id)
+            VALUES (%s, %s)
+        """, (user_id, server_id))
 
     async def get_all_users(self) -> List[UserData]:
         """Get all users from the database."""
-        return [UserData(**user_dict) for user_dict in self._data.values()]
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("""
+                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
+                FROM users u
+                LEFT JOIN user_servers us ON u.user_id = us.user_id
+                GROUP BY u.user_id
+            """)
+            results = await cursor.fetchall()
+            return [self._parse_user_result(r) for r in results]
 
     async def get_users_by_server(self, server_id: int) -> List[UserData]:
         """Get all users that are members of a specific server."""
-        users = []
-        for user_dict in self._data.values():
-            if server_id in user_dict.get('servers', []):
-                users.append(UserData(**user_dict))
-        return users
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("""
+                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
+                FROM users u
+                JOIN user_servers us ON u.user_id = us.user_id
+                WHERE us.server_id = %s
+                GROUP BY u.user_id
+            """, (server_id,))
+            results = await cursor.fetchall()
+            return [self._parse_user_result(r) for r in results]
 
     async def get_user_count(self) -> int:
         """Get total number of users in database."""
-        return len(self._data)
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("SELECT COUNT(*) FROM users")
+            result = await cursor.fetchone()
+            return result['COUNT(*)']
 
     async def get_server_count(self) -> int:
         """Get total number of unique servers."""
-        servers = set()
-        for user_dict in self._data.values():
-            servers.update(user_dict.get('servers', []))
-        return len(servers)
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("SELECT COUNT(DISTINCT server_id) FROM user_servers")
+            result = await cursor.fetchone()
+            return result['COUNT(DISTINCT server_id)']
 
-    async def cleanup_old_backups(self, max_backups: int = 10):
-        """Clean up old backup files, keeping only the most recent ones."""
-        backup_files = sorted(self.backup_path.glob("users_backup_*.json"))
-
-        if len(backup_files) > max_backups:
-            files_to_remove = backup_files[:-max_backups]
-            for file_path in files_to_remove:
-                try:
-                    file_path.unlink()
-                    self.logger.info(f"Removed old backup: {file_path.name}")
-                except Exception as e:
-                    self.logger.error(f"Error removing backup {file_path.name}: {e}")
-
-    async def export_to_csv(self, output_path: str):
-        """Export user data to CSV format."""
-        import csv
-
-        output_path = Path(output_path)
-
-        try:
-            with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
-                fieldnames = ['user_id', 'username', 'discriminator', 'display_name',
-                              'avatar_url', 'bio', 'status', 'servers', 'created_at', 'updated_at']
-                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-                writer.writeheader()
-
-                for user_dict in self._data.values():
-                    # Convert servers list to string
-                    user_dict_copy = user_dict.copy()
-                    user_dict_copy['servers'] = ','.join(map(str, user_dict.get('servers', [])))
-                    writer.writerow(user_dict_copy)
-
-            self.logger.info(f"Exported {len(self._data)} users to {output_path}")
-        except Exception as e:
-            self.logger.error(f"Error exporting to CSV: {e}")
 
     async def get_statistics(self) -> Dict[str, Any]:
-        """Get database statistics."""
+        """Get database statistics using optimized queries."""
         stats = {
             'total_users': await self.get_user_count(),
             'total_servers': await self.get_server_count(),
-            'database_size': self.database_path.stat().st_size if self.database_path.exists() else 0
         }
 
-        # Most active servers
-        server_counts = {}
-        for user_dict in self._data.values():
-            for server_id in user_dict.get('servers', []):
-                server_counts[server_id] = server_counts.get(server_id, 0) + 1
-
-        stats['most_active_servers'] = sorted(server_counts.items(),
-                                              key=lambda x: x[1], reverse=True)[:10]
+        async with self.pool.cursor() as cursor:
+            await cursor.execute("""
+                SELECT server_id, COUNT(user_id) as user_count
+                FROM user_servers
+                GROUP BY server_id
+                ORDER BY user_count DESC
+                LIMIT 10
+            """)
+            stats['most_active_servers'] = await cursor.fetchall()
 
         return stats
+
+    async def close(self):
+        """Close database connection."""
+        if self.pool:
+            await self.pool.close()
+            self.logger.info("Database connection closed")