i have no sweet clue if this works
This commit is contained in:
parent
a9bcce85d6
commit
5dbc12e943
379
src/database.py
379
src/database.py
|
@ -1,15 +1,14 @@
|
|||
"""
|
||||
JSON database manager for Discord user data storage.
|
||||
MariaDB database manager for Discord user data storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from asyncmy import connect, Connection, Cursor
|
||||
from asyncmy.errors import MySQLError
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -25,181 +24,257 @@ class UserData:
|
|||
status: Optional[str] = None
|
||||
activity: Optional[str] = None
|
||||
servers: List[int] = None
|
||||
created_at: str = None
|
||||
updated_at: str = None
|
||||
|
||||
created_at: datetime = None
|
||||
updated_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.servers is None:
|
||||
self.servers = []
|
||||
|
||||
current_time = datetime.utcnow().isoformat()
|
||||
current_time = datetime.utcnow()
|
||||
if self.created_at is None:
|
||||
self.created_at = current_time
|
||||
self.updated_at = current_time
|
||||
|
||||
|
||||
class JSONDatabase:
|
||||
"""JSON-based database for storing Discord user data."""
|
||||
class MariaDBDatabase:
|
||||
"""MariaDB-based database for storing Discord user data."""
|
||||
|
||||
def __init__(self, database_path: str):
|
||||
"""Initialize the JSON database."""
|
||||
self.database_path = Path(database_path)
|
||||
self.backup_path = Path("data/backups")
|
||||
def __init__(self,
|
||||
host: str,
|
||||
user: str,
|
||||
password: str,
|
||||
database: str,
|
||||
port: int = 3306):
|
||||
"""Initialize the MariaDB connection."""
|
||||
self.db_config = {
|
||||
'host': host,
|
||||
'port': port,
|
||||
'user': user,
|
||||
'password': password,
|
||||
'db': database,
|
||||
}
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.pool = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._data: Dict[str, Dict] = {}
|
||||
|
||||
# Ensure database directory exists
|
||||
self.database_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.backup_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing data
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self):
|
||||
"""Load data from JSON file."""
|
||||
if self.database_path.exists():
|
||||
try:
|
||||
with open(self.database_path, 'r', encoding='utf-8') as f:
|
||||
self._data = json.load(f)
|
||||
self.logger.info(f"Loaded {len(self._data)} users from database")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error loading database: {e}")
|
||||
self._data = {}
|
||||
else:
|
||||
self._data = {}
|
||||
self.logger.info("Created new database")
|
||||
|
||||
async def _save_data(self):
|
||||
"""Save data to JSON file."""
|
||||
# Table schema versions
|
||||
self.schema_version = 1
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize database connection and ensure tables exist."""
|
||||
try:
|
||||
self.pool = await connect(**self.db_config)
|
||||
await self._create_tables()
|
||||
self.logger.info("Database connection established")
|
||||
except MySQLError as e:
|
||||
self.logger.error(f"Database connection failed: {e}")
|
||||
raise
|
||||
|
||||
async def _create_tables(self):
|
||||
"""Create necessary tables if they don't exist."""
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id BIGINT PRIMARY KEY,
|
||||
username VARCHAR(32) NOT NULL,
|
||||
discriminator VARCHAR(4) NOT NULL,
|
||||
display_name VARCHAR(32),
|
||||
avatar_url VARCHAR(255),
|
||||
banner_url VARCHAR(255),
|
||||
bio TEXT,
|
||||
status VARCHAR(20),
|
||||
activity VARCHAR(50),
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
await cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS user_servers (
|
||||
user_id BIGINT,
|
||||
server_id BIGINT,
|
||||
PRIMARY KEY (user_id, server_id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
await cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INT PRIMARY KEY,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Check schema version
|
||||
await cursor.execute("""
|
||||
INSERT IGNORE INTO schema_version (version) VALUES (%s)
|
||||
""", (self.schema_version,))
|
||||
|
||||
self.logger.info("Database tables verified/created")
|
||||
|
||||
async def _execute_query(self, query: str, params: tuple = None):
|
||||
"""Execute a database query with error handling."""
|
||||
async with self._lock:
|
||||
try:
|
||||
# Create backup before saving
|
||||
if self.database_path.exists():
|
||||
backup_filename = f"users_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
backup_path = self.backup_path / backup_filename
|
||||
shutil.copy2(self.database_path, backup_path)
|
||||
|
||||
# Save data
|
||||
with open(self.database_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self._data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
self.logger.debug(f"Saved {len(self._data)} users to database")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error saving database: {e}")
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute(query, params)
|
||||
return cursor
|
||||
except MySQLError as e:
|
||||
self.logger.error(f"Database error: {e}")
|
||||
raise
|
||||
|
||||
async def get_user(self, user_id: int) -> Optional[UserData]:
|
||||
"""Get user data by ID."""
|
||||
user_key = str(user_id)
|
||||
if user_key in self._data:
|
||||
user_dict = self._data[user_key]
|
||||
return UserData(**user_dict)
|
||||
return None
|
||||
|
||||
"""Get user data by ID with associated servers."""
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||
FROM users u
|
||||
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||
WHERE u.user_id = %s
|
||||
GROUP BY u.user_id
|
||||
""", (user_id,))
|
||||
|
||||
result = await cursor.fetchone()
|
||||
if result:
|
||||
return self._parse_user_result(result)
|
||||
return None
|
||||
|
||||
def _parse_user_result(self, result: Dict) -> UserData:
|
||||
"""Convert database result to UserData object."""
|
||||
servers = list(map(int, result['servers'].split(','))) if result['servers'] else []
|
||||
return UserData(
|
||||
user_id=result['user_id'],
|
||||
username=result['username'],
|
||||
discriminator=result['discriminator'],
|
||||
display_name=result['display_name'],
|
||||
avatar_url=result['avatar_url'],
|
||||
banner_url=result['banner_url'],
|
||||
bio=result['bio'],
|
||||
status=result['status'],
|
||||
activity=result['activity'],
|
||||
servers=servers,
|
||||
created_at=result['created_at'],
|
||||
updated_at=result['updated_at']
|
||||
)
|
||||
|
||||
async def save_user(self, user_data: UserData):
|
||||
"""Save or update user data."""
|
||||
user_key = str(user_data.user_id)
|
||||
|
||||
# If user exists, preserve created_at timestamp
|
||||
if user_key in self._data:
|
||||
user_data.created_at = self._data[user_key]['created_at']
|
||||
|
||||
# Update timestamp
|
||||
user_data.updated_at = datetime.utcnow().isoformat()
|
||||
|
||||
# Save to memory
|
||||
self._data[user_key] = asdict(user_data)
|
||||
|
||||
# Save to disk
|
||||
await self._save_data()
|
||||
|
||||
self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator} ({user_data.user_id})")
|
||||
|
||||
"""Save or update user data with transaction."""
|
||||
async with self.pool.begin() as conn:
|
||||
try:
|
||||
# Upsert user data
|
||||
await conn.execute("""
|
||||
INSERT INTO users (
|
||||
user_id, username, discriminator, display_name,
|
||||
avatar_url, banner_url, bio, status, activity
|
||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
username = VALUES(username),
|
||||
discriminator = VALUES(discriminator),
|
||||
display_name = VALUES(display_name),
|
||||
avatar_url = VALUES(avatar_url),
|
||||
banner_url = VALUES(banner_url),
|
||||
bio = VALUES(bio),
|
||||
status = VALUES(status),
|
||||
activity = VALUES(activity)
|
||||
""", (
|
||||
user_data.user_id,
|
||||
user_data.username,
|
||||
user_data.discriminator,
|
||||
user_data.display_name,
|
||||
user_data.avatar_url,
|
||||
user_data.banner_url,
|
||||
user_data.bio,
|
||||
user_data.status,
|
||||
user_data.activity
|
||||
))
|
||||
|
||||
# Update servers relationship
|
||||
await conn.execute(
|
||||
"DELETE FROM user_servers WHERE user_id = %s",
|
||||
(user_data.user_id,)
|
||||
)
|
||||
|
||||
if user_data.servers:
|
||||
server_values = [(user_data.user_id, s) for s in user_data.servers]
|
||||
await conn.executemany(
|
||||
"INSERT INTO user_servers (user_id, server_id) VALUES (%s, %s)",
|
||||
server_values
|
||||
)
|
||||
|
||||
self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
|
||||
|
||||
except MySQLError as e:
|
||||
self.logger.error(f"Error saving user: {e}")
|
||||
await conn.rollback()
|
||||
raise
|
||||
|
||||
async def add_server_to_user(self, user_id: int, server_id: int):
|
||||
"""Add a server to user's server list."""
|
||||
user_key = str(user_id)
|
||||
if user_key in self._data:
|
||||
if server_id not in self._data[user_key]['servers']:
|
||||
self._data[user_key]['servers'].append(server_id)
|
||||
self._data[user_key]['updated_at'] = datetime.utcnow().isoformat()
|
||||
await self._save_data()
|
||||
|
||||
await self._execute_query("""
|
||||
INSERT IGNORE INTO user_servers (user_id, server_id)
|
||||
VALUES (%s, %s)
|
||||
""", (user_id, server_id))
|
||||
|
||||
async def get_all_users(self) -> List[UserData]:
|
||||
"""Get all users from the database."""
|
||||
return [UserData(**user_dict) for user_dict in self._data.values()]
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||
FROM users u
|
||||
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||
GROUP BY u.user_id
|
||||
""")
|
||||
results = await cursor.fetchall()
|
||||
return [self._parse_user_result(r) for r in results]
|
||||
|
||||
async def get_users_by_server(self, server_id: int) -> List[UserData]:
|
||||
"""Get all users that are members of a specific server."""
|
||||
users = []
|
||||
for user_dict in self._data.values():
|
||||
if server_id in user_dict.get('servers', []):
|
||||
users.append(UserData(**user_dict))
|
||||
return users
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||
FROM users u
|
||||
JOIN user_servers us ON u.user_id = us.user_id
|
||||
WHERE us.server_id = %s
|
||||
GROUP BY u.user_id
|
||||
""", (server_id,))
|
||||
results = await cursor.fetchall()
|
||||
return [self._parse_user_result(r) for r in results]
|
||||
|
||||
async def get_user_count(self) -> int:
|
||||
"""Get total number of users in database."""
|
||||
return len(self._data)
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM users")
|
||||
result = await cursor.fetchone()
|
||||
return result['COUNT(*)']
|
||||
|
||||
async def get_server_count(self) -> int:
|
||||
"""Get total number of unique servers."""
|
||||
servers = set()
|
||||
for user_dict in self._data.values():
|
||||
servers.update(user_dict.get('servers', []))
|
||||
return len(servers)
|
||||
|
||||
async def cleanup_old_backups(self, max_backups: int = 10):
|
||||
"""Clean up old backup files, keeping only the most recent ones."""
|
||||
backup_files = sorted(self.backup_path.glob("users_backup_*.json"))
|
||||
|
||||
if len(backup_files) > max_backups:
|
||||
files_to_remove = backup_files[:-max_backups]
|
||||
for file_path in files_to_remove:
|
||||
try:
|
||||
file_path.unlink()
|
||||
self.logger.info(f"Removed old backup: {file_path.name}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error removing backup {file_path.name}: {e}")
|
||||
|
||||
async def export_to_csv(self, output_path: str):
|
||||
"""Export user data to CSV format."""
|
||||
import csv
|
||||
|
||||
output_path = Path(output_path)
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['user_id', 'username', 'discriminator', 'display_name',
|
||||
'avatar_url', 'bio', 'status', 'servers', 'created_at', 'updated_at']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
for user_dict in self._data.values():
|
||||
# Convert servers list to string
|
||||
user_dict_copy = user_dict.copy()
|
||||
user_dict_copy['servers'] = ','.join(map(str, user_dict.get('servers', [])))
|
||||
writer.writerow(user_dict_copy)
|
||||
|
||||
self.logger.info(f"Exported {len(self._data)} users to {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error exporting to CSV: {e}")
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(DISTINCT server_id) FROM user_servers")
|
||||
result = await cursor.fetchone()
|
||||
return result['COUNT(DISTINCT server_id)']
|
||||
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get database statistics."""
|
||||
"""Get database statistics using optimized queries."""
|
||||
stats = {
|
||||
'total_users': await self.get_user_count(),
|
||||
'total_servers': await self.get_server_count(),
|
||||
'database_size': self.database_path.stat().st_size if self.database_path.exists() else 0
|
||||
}
|
||||
|
||||
# Most active servers
|
||||
server_counts = {}
|
||||
for user_dict in self._data.values():
|
||||
for server_id in user_dict.get('servers', []):
|
||||
server_counts[server_id] = server_counts.get(server_id, 0) + 1
|
||||
|
||||
stats['most_active_servers'] = sorted(server_counts.items(),
|
||||
key=lambda x: x[1], reverse=True)[:10]
|
||||
|
||||
return stats
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
SELECT server_id, COUNT(user_id) as user_count
|
||||
FROM user_servers
|
||||
GROUP BY server_id
|
||||
ORDER BY user_count DESC
|
||||
LIMIT 10
|
||||
""")
|
||||
stats['most_active_servers'] = await cursor.fetchall()
|
||||
|
||||
return stats
|
||||
|
||||
async def close(self):
|
||||
"""Close database connection."""
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
self.logger.info("Database connection closed")
|
||||
|
|
Loading…
Reference in a new issue