i have no sweet clue if this works

glitchy 2025-07-13 21:30:12 +02:00
parent a9bcce85d6
commit 5dbc12e943


"""
JSON database manager for Discord user data storage.
MariaDB database manager for Discord user data storage.
"""
import json
import asyncio
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from dataclasses import dataclass
import logging
from asyncmy import connect, Connection, Cursor
from asyncmy.errors import MySQLError
@dataclass
class UserData:
    # Fields reconstructed to match the users table schema below.
    user_id: int
    username: str
    discriminator: str
    display_name: Optional[str] = None
    avatar_url: Optional[str] = None
    banner_url: Optional[str] = None
    bio: Optional[str] = None
    status: Optional[str] = None
    activity: Optional[str] = None
    servers: List[int] = None
    created_at: datetime = None
    updated_at: datetime = None

    def __post_init__(self):
        if self.servers is None:
            self.servers = []
        current_time = datetime.utcnow()
        if self.created_at is None:
            self.created_at = current_time
        self.updated_at = current_time
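
# A record can be built with just the required fields; __post_init__
# fills in servers and both timestamps, e.g.:
#   UserData(user_id=1, username="example", discriminator="0001")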


class MariaDBDatabase:
    """MariaDB-based database for storing Discord user data."""

    def __init__(self,
                 host: str,
                 user: str,
                 password: str,
                 database: str,
                 port: int = 3306):
        """Store the MariaDB connection settings."""
        self.db_config = {
            'host': host,
            'port': port,
            'user': user,
            'password': password,
            'db': database,
            # Without autocommit, single INSERT/UPDATE statements issued
            # outside an explicit transaction would never be persisted.
            'autocommit': True,
        }
        self.logger = logging.getLogger(__name__)
        self.pool = None
        # Table schema version
        self.schema_version = 1

    async def initialize(self):
        """Initialize the connection pool and ensure tables exist."""
        try:
            self.pool = await create_pool(**self.db_config)
            await self._create_tables()
            self.logger.info("Database connection established")
        except MySQLError as e:
            self.logger.error(f"Database connection failed: {e}")
            raise

    async def _create_tables(self):
        """Create necessary tables if they don't exist."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS users (
                        user_id BIGINT PRIMARY KEY,
                        username VARCHAR(32) NOT NULL,
                        discriminator VARCHAR(4) NOT NULL,
                        display_name VARCHAR(32),
                        avatar_url VARCHAR(255),
                        banner_url VARCHAR(255),
                        bio TEXT,
                        status VARCHAR(20),
                        activity VARCHAR(50),
                        created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    )
                """)
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS user_servers (
                        user_id BIGINT,
                        server_id BIGINT,
                        PRIMARY KEY (user_id, server_id),
                        FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
                    )
                """)
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_version (
                        version INT PRIMARY KEY,
                        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
                    )
                """)
                # Record the schema version; INSERT IGNORE makes this a
                # no-op when the row already exists.
                await cursor.execute("""
                    INSERT IGNORE INTO schema_version (version) VALUES (%s)
                """, (self.schema_version,))
        self.logger.info("Database tables verified/created")

    async def _execute_query(self, query: str, params: tuple = None):
        """Execute a single statement with error handling."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(query, params)
                    # The cursor is closed once the context exits, so hand
                    # back the affected-row count rather than the cursor.
                    return cursor.rowcount
        except MySQLError as e:
            self.logger.error(f"Database error: {e}")
            raise

    async def get_user(self, user_id: int) -> Optional[UserData]:
        """Get user data by ID with associated servers."""
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                # GROUP_CONCAT folds the joined server rows into a single
                # comma-separated string per user.
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    WHERE u.user_id = %s
                    GROUP BY u.user_id
                """, (user_id,))
                result = await cursor.fetchone()
        if result:
            return self._parse_user_result(result)
        return None

    def _parse_user_result(self, result: Dict) -> UserData:
        """Convert a database row to a UserData object."""
        servers = list(map(int, result['servers'].split(','))) if result['servers'] else []
        return UserData(
            user_id=result['user_id'],
            username=result['username'],
            discriminator=result['discriminator'],
            display_name=result['display_name'],
            avatar_url=result['avatar_url'],
            banner_url=result['banner_url'],
            bio=result['bio'],
            status=result['status'],
            activity=result['activity'],
            servers=servers,
            created_at=result['created_at'],
            updated_at=result['updated_at']
        )
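
    # With DictCursor the row arrives as a dict, e.g. {'user_id': 1, ...,
    # 'servers': '111,222'}; the helper above turns that into [111, 222].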

    async def save_user(self, user_data: UserData):
        """Save or update user data in a single transaction."""
        async with self.pool.acquire() as conn:
            try:
                await conn.begin()
                async with conn.cursor() as cursor:
                    # Upsert the user row.
                    await cursor.execute("""
                        INSERT INTO users (
                            user_id, username, discriminator, display_name,
                            avatar_url, banner_url, bio, status, activity
                        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            username = VALUES(username),
                            discriminator = VALUES(discriminator),
                            display_name = VALUES(display_name),
                            avatar_url = VALUES(avatar_url),
                            banner_url = VALUES(banner_url),
                            bio = VALUES(bio),
                            status = VALUES(status),
                            activity = VALUES(activity)
                    """, (
                        user_data.user_id,
                        user_data.username,
                        user_data.discriminator,
                        user_data.display_name,
                        user_data.avatar_url,
                        user_data.banner_url,
                        user_data.bio,
                        user_data.status,
                        user_data.activity
                    ))
                    # Replace the server relationships wholesale: delete
                    # the old rows, then bulk-insert the new set.
                    await cursor.execute(
                        "DELETE FROM user_servers WHERE user_id = %s",
                        (user_data.user_id,)
                    )
                    if user_data.servers:
                        server_values = [(user_data.user_id, s) for s in user_data.servers]
                        await cursor.executemany(
                            "INSERT INTO user_servers (user_id, server_id) VALUES (%s, %s)",
                            server_values
                        )
                await conn.commit()
                self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
            except MySQLError as e:
                self.logger.error(f"Error saving user: {e}")
                await conn.rollback()
                raise

    async def add_server_to_user(self, user_id: int, server_id: int):
        """Add a server to a user's server list."""
        # INSERT IGNORE makes this a no-op when the pair already exists.
        await self._execute_query("""
            INSERT IGNORE INTO user_servers (user_id, server_id)
            VALUES (%s, %s)
        """, (user_id, server_id))

    async def get_all_users(self) -> List[UserData]:
        """Get all users from the database."""
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    GROUP BY u.user_id
                """)
                results = await cursor.fetchall()
        return [self._parse_user_result(r) for r in results]

    async def get_users_by_server(self, server_id: int) -> List[UserData]:
        """Get all users that are members of a specific server."""
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                # Join user_servers twice: once to filter by the requested
                # server, and again to collect each user's full server list
                # (a single filtered join would report only the one server).
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(all_us.server_id) AS servers
                    FROM users u
                    JOIN user_servers us ON u.user_id = us.user_id
                    LEFT JOIN user_servers all_us ON u.user_id = all_us.user_id
                    WHERE us.server_id = %s
                    GROUP BY u.user_id
                """, (server_id,))
                results = await cursor.fetchall()
        return [self._parse_user_result(r) for r in results]

    async def get_user_count(self) -> int:
        """Get total number of users in the database."""
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                # Alias the aggregate so the dict key is stable.
                await cursor.execute("SELECT COUNT(*) AS cnt FROM users")
                result = await cursor.fetchone()
        return result['cnt']

    async def get_server_count(self) -> int:
        """Get total number of unique servers."""
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                await cursor.execute(
                    "SELECT COUNT(DISTINCT server_id) AS cnt FROM user_servers"
                )
                result = await cursor.fetchone()
        return result['cnt']

    async def get_statistics(self) -> Dict[str, Any]:
        """Get database statistics using optimized queries."""
        stats = {
            'total_users': await self.get_user_count(),
            'total_servers': await self.get_server_count(),
        }
        async with self.pool.acquire() as conn:
            async with conn.cursor(cursor=DictCursor) as cursor:
                # Top ten servers by member count.
                await cursor.execute("""
                    SELECT server_id, COUNT(user_id) AS user_count
                    FROM user_servers
                    GROUP BY server_id
                    ORDER BY user_count DESC
                    LIMIT 10
                """)
                stats['most_active_servers'] = await cursor.fetchall()
        return stats

    async def close(self):
        """Close the connection pool."""
        if self.pool:
            self.pool.close()
            await self.pool.wait_closed()
            self.logger.info("Database connection closed")