why is there optimization in my racism app (d*scord)
This commit is contained in:
parent
10f66b95fd
commit
3131b0c839
139	DATABASE_OPTIMIZATION.md	Normal file

@@ -0,0 +1,139 @@
# Database Optimization for 50k+ Users

## 1. Database Schema Improvements

### Add Indexes
```sql
-- Primary performance indexes
CREATE INDEX idx_users_username ON users(username);
CREATE INDEX idx_users_display_name ON users(display_name);
CREATE INDEX idx_users_updated_at ON users(updated_at);
CREATE INDEX idx_user_servers_server_id ON user_servers(server_id);

-- Composite indexes for common queries
CREATE INDEX idx_users_username_display ON users(username, display_name);
```

### Query Optimization
```sql
-- Instead of GROUP_CONCAT, use separate queries
SELECT * FROM users WHERE user_id = ?;
SELECT server_id FROM user_servers WHERE user_id = ?;
```
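
As a rough illustration, the two smaller queries can be stitched back together in application code. This sketch assumes the aiomysql-style pool from Section 2 below and a default (tuple) cursor; it is not the project's actual helper:

```python
# Hypothetical helper: one indexed lookup per table instead of a GROUP_CONCAT join.
async def fetch_user_with_servers(pool, user_id: int):
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute("SELECT * FROM users WHERE user_id = %s", (user_id,))
            user_row = await cursor.fetchone()
            if user_row is None:
                return None
            await cursor.execute(
                "SELECT server_id FROM user_servers WHERE user_id = %s", (user_id,)
            )
            server_ids = [row[0] for row in await cursor.fetchall()]
            return user_row, server_ids
```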

## 2. Connection Pool Implementation

Replace single connection with proper pooling:

```python
import aiomysql

async def create_pool():
    return await aiomysql.create_pool(
        host='localhost',
        port=3306,
        user='user',
        password='password',
        db='database',
        minsize=5,
        maxsize=20,
        charset='utf8mb4'
    )
```
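
Connections are then borrowed per operation and the pool is shut down cleanly on exit. A minimal usage sketch (the query shown is only an example):

```python
pool = await create_pool()
try:
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute("SELECT COUNT(*) FROM users")
            (total_users,) = await cursor.fetchone()
finally:
    pool.close()              # stop handing out new connections
    await pool.wait_closed()  # wait for in-flight connections to be returned
```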

## 3. Pagination Implementation

### Database Layer
```python
async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
    async with self.pool.acquire() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute("""
                SELECT * FROM users
                ORDER BY user_id
                LIMIT %s OFFSET %s
            """, (limit, offset))
            rows = await cursor.fetchall()
            # Map raw rows to UserData (using whichever row-parsing helper the class provides)
            return [self._parse_user_result(row) for row in rows]
```
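
A caller can then walk the whole table one page at a time instead of loading everything at once; a hedged sketch (page size and the per-user processing are illustrative):

```python
offset = 0
while True:
    page = await db.get_users_paginated(offset=offset, limit=500)
    if not page:
        break
    for user in page:
        handle(user)  # hypothetical per-user processing
    offset += len(page)
```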

### CLI Layer
```python
async def search_user_paginated(query: str, page: int = 1, per_page: int = 50):
    offset = (page - 1) * per_page
    users = await database.search_users(query, offset, per_page)
    # Display with pagination controls
```
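
For illustration, the display step could mirror the cli.py changes later in this commit (the exact formatting is illustrative):

```python
# Continues search_user_paginated after `users` has been fetched for this page.
total_count = await database.get_user_count_total()
total_pages = (total_count + per_page - 1) // per_page
print(f"=== Search Results (Page {page} of {total_pages}) ===")
for user in users:
    print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
print(f"Showing {len(users)} of {total_count} total users")
if page < total_pages:
    print(f"Use --page {page + 1} for the next page")
```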

## 4. Search Optimization

### Full-Text Search
```sql
-- Add full-text index for better search
ALTER TABLE users ADD FULLTEXT(username, display_name, bio);

-- Use full-text search instead of LIKE
-- (the MATCH column list must match the FULLTEXT index exactly)
SELECT * FROM users
WHERE MATCH(username, display_name, bio) AGAINST(? IN BOOLEAN MODE);
```

### Cached Search Results
```python
# Cache frequent searches. functools.lru_cache is not safe for coroutine
# functions (it would cache the coroutine object, which can only be awaited
# once), so keep a small result cache keyed by query instead.
_search_cache: dict = {}

async def cached_user_search(query: str):
    if query in _search_cache:
        return _search_cache[query]
    results = await database.search_users(query)
    if len(_search_cache) < 1000:  # rough equivalent of maxsize=1000
        _search_cache[query] = results
    return results
```

## 5. Batch Operations

### Bulk Inserts
```python
async def save_users_batch(self, users: List[UserData]):
    async with self.pool.acquire() as conn:
        async with conn.cursor() as cursor:
            # Use executemany for bulk operations
            await cursor.executemany("""
                INSERT INTO users (...) VALUES (...)
                ON DUPLICATE KEY UPDATE ...
            """, [(user.user_id, user.username, ...) for user in users])
```
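
Callers would then feed users to this method in manageable chunks rather than one huge list; a hedged sketch (the chunk size is arbitrary):

```python
# Hypothetical driver: keep each transaction to a modest number of rows.
async def save_all_users(db, users, chunk_size: int = 1000):
    for start in range(0, len(users), chunk_size):
        await db.save_users_batch(users[start:start + chunk_size])
```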

## 6. Rate Limiting Improvements

### Smarter Rate Limiting
```python
import asyncio

class AdaptiveRateLimiter:
    def __init__(self):
        self.base_delay = 1.0
        self.consecutive_429s = 0

    async def wait(self):
        if self.consecutive_429s > 0:
            delay = self.base_delay * (2 ** self.consecutive_429s)
            await asyncio.sleep(min(delay, 60))  # Cap at 60 seconds
        else:
            await asyncio.sleep(self.base_delay)
```
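
One way this limiter might wrap an API call, assuming a discord.py-style client and HTTPException (the fuller AdaptiveRateLimiter added to rate_limiter.py further down in this commit also tracks successes and honours Discord's retry_after):

```python
import discord  # assumed: the discord.py-style library the client already uses

limiter = AdaptiveRateLimiter()

async def fetch_profile_with_backoff(client, user_id: int):
    await limiter.wait()
    try:
        profile = await client.fetch_user_profile(user_id, with_mutual_guilds=True)
        limiter.consecutive_429s = 0  # success: reset the backoff
        return profile
    except discord.HTTPException as e:
        if e.status == 429:  # rate limited: back off harder next time
            limiter.consecutive_429s += 1
        raise
```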

## Performance Estimates (50k users):

### Current Implementation:
- **get_all_users()**: ~30-60 seconds + 2-4GB RAM
- **CLI search**: ~10-30 seconds per search
- **Database saves**: ~5-10x slower due to locking

### Optimized Implementation:
- **Paginated queries**: ~0.1-0.5 seconds per page
- **Indexed search**: ~0.1-1 second per search
- **Connection pool**: ~2-3x faster concurrent operations
- **Memory usage**: ~50-100MB instead of GB

## Implementation Priority:

1. **High Priority**: Connection pooling, pagination, indexes
2. **Medium Priority**: Search optimization, batch operations
3. **Low Priority**: Caching, adaptive rate limiting

The optimized version should handle 50k users with reasonable performance (roughly 10x slower than the 500-user case, instead of 100x slower).

32	cli.py

@@ -52,27 +52,22 @@ async def show_stats():
        print(f" Server {server_id}: {user_count} users")


async def search_user(query: str):
    """Search for users."""
async def search_user(query: str, page: int = 1, per_page: int = 10):
    """Search for users with pagination."""
    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())

    all_users = await database.get_all_users()

    # Search by username or user ID
    results = []
    for user in all_users:
        if (query.lower() in user.username.lower() or
                query.lower() in (user.display_name or "").lower() or
                query == str(user.user_id)):
            results.append(user)
    offset = (page - 1) * per_page
    results = await database.search_users(query, offset, per_page)

    if not results:
        print("No users found matching the query.")
        return

    print(f"\n=== Found {len(results)} users ===")
    for user in results[:10]:  # Show first 10 results
    total_count = await database.get_user_count_total()
    total_pages = (total_count + per_page - 1) // per_page
    print(f"\n=== Search Results (Page {page} of {total_pages}) ===")
    for user in results:
        print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
        if user.display_name:
            print(f" Display name: {user.display_name}")

@@ -98,6 +93,13 @@ async def search_user(query: str):

        print(f" Last updated: {user.updated_at}")
        print()

    # Add pagination navigation info
    print(f"\nShowing {len(results)} of {total_count} total users")
    if total_pages > 1:
        print(f"Use --page {page + 1} for next page" if page < total_pages else "Last page")

    await database.close()


async def list_servers():

@@ -292,6 +294,8 @@ def main():

    # Search command
    search_parser = subparsers.add_parser("search", help="Search for users")
    search_parser.add_argument("query", help="Search query (username or user ID)")
    search_parser.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
    search_parser.add_argument("--per-page", type=int, default=10, help="Results per page (default: 10)")

    # Server commands
    servers_parser = subparsers.add_parser("servers", help="List all servers with user counts")

@@ -323,7 +327,7 @@ def main():

    elif args.command == "stats":
        asyncio.run(show_stats())
    elif args.command == "search":
        asyncio.run(search_user(args.query))
        asyncio.run(search_user(args.query, args.page, getattr(args, 'per_page', 10)))
    elif args.command == "servers":
        asyncio.run(list_servers())
    elif args.command == "user-servers":

@@ -15,7 +15,7 @@ except ImportError:

from .config import Config
from .database import UserData
from .rate_limiter import RateLimiter
from .rate_limiter import RateLimiter, AdaptiveRateLimiter


class DiscordDataClient(discord.Client):

@@ -26,9 +26,9 @@ class DiscordDataClient(discord.Client):

        self.config = config
        self.database = database
        self.rate_limiter = RateLimiter(
            requests_per_minute=config.max_requests_per_minute,
            delay_between_requests=config.request_delay
        self.rate_limiter = AdaptiveRateLimiter(
            base_delay=config.request_delay,
            max_delay=60.0
        )

        self.logger = logging.getLogger(__name__)

@@ -251,7 +251,16 @@ class DiscordDataClient(discord.Client):

        await self.rate_limiter.wait()

        # Use fetch_user_profile to get mutual guilds
        profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True)
        try:
            profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True)
            self.rate_limiter.on_success()
        except discord.HTTPException as e:
            if e.status == 429:  # Rate limited
                retry_after = getattr(e, 'retry_after', None)
                self.rate_limiter.on_rate_limit(retry_after)
                raise
            else:
                raise

        if hasattr(profile, 'mutual_guilds') and profile.mutual_guilds:
            mutual_guild_ids = []

585	src/database.py

@@ -15,7 +15,7 @@ import logging
# Optional MariaDB support
try:
    from asyncmy import connect
    from asyncmy import connect, create_pool
    from asyncmy.cursors import DictCursor
    from asyncmy.errors import MySQLError
    MARIADB_AVAILABLE = True

@@ -255,6 +255,44 @@ class JSONDatabase:
                result[server_id] = servers[str(server_id)]
            return result

    async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
        """Get users with pagination (JSON implementation)."""
        async with self._lock:
            data = self._load_data()
            users = []
            for user_id, user_data in data.items():
                if user_id != "servers":  # Skip servers data
                    users.append(UserData.from_dict(user_data))

            # Sort by user_id and paginate
            users.sort(key=lambda u: u.user_id)
            return users[offset:offset + limit]

    async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]:
        """Search users with pagination (JSON implementation)."""
        async with self._lock:
            data = self._load_data()
            matching_users = []

            for user_id, user_data in data.items():
                if user_id != "servers":  # Skip servers data
                    user = UserData.from_dict(user_data)
                    if (query.lower() in user.username.lower() or
                            query.lower() in (user.display_name or "").lower() or
                            query == str(user.user_id)):
                        matching_users.append(user)

            # Sort by user_id and paginate
            matching_users.sort(key=lambda u: u.user_id)
            return matching_users[offset:offset + limit]

    async def get_user_count_total(self) -> int:
        """Get total number of users (JSON implementation)."""
        async with self._lock:
            data = self._load_data()
            # Count all keys except "servers"
            return len([k for k in data.keys() if k != "servers"])

    async def close(self):
        """Close database (no-op for JSON database)."""
        self.logger.info("JSON database closed")

@@ -284,101 +322,132 @@ class MariaDBDatabase:
        }
        self.logger = logging.getLogger(__name__)
        self.pool = None
        self._lock = asyncio.Lock()

        # Table schema versions
        self.schema_version = 1

    async def initialize(self):
        """Initialize database connection and ensure tables exist."""
        """Initialize database connection pool and ensure tables exist."""
        try:
            # Add DictCursor to config for dictionary results
            self.db_config['cursor_cls'] = DictCursor
            self.pool = await connect(**self.db_config)
            # Create connection pool instead of single connection
            self.pool = await create_pool(
                host=self.db_config['host'],
                port=self.db_config['port'],
                user=self.db_config['user'],
                password=self.db_config['password'],
                db=self.db_config['db'],
                minsize=5,
                maxsize=20,
                charset='utf8mb4',
                cursorclass=DictCursor
            )
            await self._create_tables()
            self.logger.info("Database connection established")
            await self._create_indexes()
            self.logger.info("Database connection pool established")
        except MySQLError as e:
            self.logger.error(f"Database connection failed: {e}")
            raise

    async def _create_tables(self):
        """Create necessary tables if they don't exist."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    user_id BIGINT PRIMARY KEY,
                    username VARCHAR(32) NOT NULL,
                    discriminator VARCHAR(4) NOT NULL,
                    display_name VARCHAR(32),
                    avatar_url VARCHAR(255),
                    banner_url VARCHAR(255),
                    bio TEXT,
                    status VARCHAR(20),
                    activity VARCHAR(50),
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                )
            """)
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS users (
                        user_id BIGINT PRIMARY KEY,
                        username VARCHAR(32) NOT NULL,
                        discriminator VARCHAR(4) NOT NULL,
                        display_name VARCHAR(32),
                        avatar_url VARCHAR(255),
                        banner_url VARCHAR(255),
                        bio TEXT,
                        status VARCHAR(20),
                        activity VARCHAR(50),
                        created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    )
                """)

            await cursor.execute("""
                CREATE TABLE IF NOT EXISTS servers (
                    server_id BIGINT PRIMARY KEY,
                    server_name VARCHAR(100) NOT NULL,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                )
            """)

            await cursor.execute("""
                CREATE TABLE IF NOT EXISTS user_servers (
                    user_id BIGINT,
                    server_id BIGINT,
                    PRIMARY KEY (user_id, server_id),
                    FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE,
                    FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE
                )
            """)

            await cursor.execute("""
                CREATE TABLE IF NOT EXISTS schema_version (
                    version INT PRIMARY KEY,
                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)

                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS servers (
                        server_id BIGINT PRIMARY KEY,
                        server_name VARCHAR(100) NOT NULL,
                        created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    )
                """)

                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS user_servers (
                        user_id BIGINT,
                        server_id BIGINT,
                        PRIMARY KEY (user_id, server_id),
                        FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE,
                        FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE
                    )
                """)

                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_version (
                        version INT PRIMARY KEY,
                        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
                    )
                """)

            # Check schema version
            await cursor.execute("""
                INSERT IGNORE INTO schema_version (version) VALUES (%s)
            """, (self.schema_version,))
                # Check schema version
                await cursor.execute("""
                    INSERT IGNORE INTO schema_version (version) VALUES (%s)
                """, (self.schema_version,))

    async def _create_indexes(self):
        """Create performance indexes."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                # Create indexes for better performance
                indexes = [
                    "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)",
                    "CREATE INDEX IF NOT EXISTS idx_users_display_name ON users(display_name)",
                    "CREATE INDEX IF NOT EXISTS idx_users_updated_at ON users(updated_at)",
                    "CREATE INDEX IF NOT EXISTS idx_user_servers_server_id ON user_servers(server_id)",
                    "CREATE INDEX IF NOT EXISTS idx_users_username_display ON users(username, display_name)",
                    "CREATE INDEX IF NOT EXISTS idx_servers_name ON servers(server_name)",
                ]

                for index_sql in indexes:
                    try:
                        await cursor.execute(index_sql)
                        self.logger.debug(f"Created index: {index_sql}")
                    except MySQLError as e:
                        if "Duplicate key name" not in str(e):
                            self.logger.warning(f"Failed to create index: {e}")

                # Add full-text search index
                try:
                    await cursor.execute("""
                        ALTER TABLE users ADD FULLTEXT(username, display_name, bio)
                    """)
                except MySQLError as e:
                    if "Duplicate key name" not in str(e):
                        self.logger.debug(f"Full-text index already exists or failed: {e}")

        self.logger.info("Database tables verified/created")

    async def _execute_query(self, query: str, params: tuple = None):
        """Execute a database query with error handling."""
        async with self._lock:
            try:
                async with self.pool.cursor() as cursor:
                    await cursor.execute(query, params)
                    return cursor
            except MySQLError as e:
                self.logger.error(f"Database error: {e}")
                raise

    async def get_user(self, user_id: int) -> Optional[UserData]:
        """Get user data by ID with associated servers."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("""
                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                FROM users u
                LEFT JOIN user_servers us ON u.user_id = us.user_id
                WHERE u.user_id = %s
                GROUP BY u.user_id
            """, (user_id,))

            result = await cursor.fetchone()
            if result:
                return self._parse_user_result(result)
            return None
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    WHERE u.user_id = %s
                    GROUP BY u.user_id
                """, (user_id,))

                result = await cursor.fetchone()
                if result:
                    return self._parse_user_result(result)
                return None

    def _parse_user_result(self, result: Dict) -> UserData:
        """Convert database result to UserData object."""

@@ -400,64 +469,65 @@ class MariaDBDatabase:

    async def save_user(self, user_data: UserData):
        """Save or update user data with transaction."""
        async with self.pool.cursor() as cursor:
            try:
                # Start transaction
                await cursor.execute("START TRANSACTION")

                # Upsert user data
                await cursor.execute("""
                    INSERT INTO users (
                        user_id, username, discriminator, display_name,
                        avatar_url, banner_url, bio, status, activity
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        username = VALUES(username),
                        discriminator = VALUES(discriminator),
                        display_name = VALUES(display_name),
                        avatar_url = VALUES(avatar_url),
                        banner_url = VALUES(banner_url),
                        bio = VALUES(bio),
                        status = VALUES(status),
                        activity = VALUES(activity)
                """, (
                    user_data.user_id,
                    user_data.username,
                    user_data.discriminator,
                    user_data.display_name,
                    user_data.avatar_url,
                    user_data.banner_url,
                    user_data.bio,
                    user_data.status,
                    user_data.activity
                ))
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                try:
                    # Start transaction
                    await cursor.execute("START TRANSACTION")

                    # Upsert user data
                    await cursor.execute("""
                        INSERT INTO users (
                            user_id, username, discriminator, display_name,
                            avatar_url, banner_url, bio, status, activity
                        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            username = VALUES(username),
                            discriminator = VALUES(discriminator),
                            display_name = VALUES(display_name),
                            avatar_url = VALUES(avatar_url),
                            banner_url = VALUES(banner_url),
                            bio = VALUES(bio),
                            status = VALUES(status),
                            activity = VALUES(activity)
                    """, (
                        user_data.user_id,
                        user_data.username,
                        user_data.discriminator,
                        user_data.display_name,
                        user_data.avatar_url,
                        user_data.banner_url,
                        user_data.bio,
                        user_data.status,
                        user_data.activity
                    ))

                # Update servers relationship
                await cursor.execute(
                    "DELETE FROM user_servers WHERE user_id = %s",
                    (user_data.user_id,)
                )

                if user_data.servers:
                    for server_id in user_data.servers:
                        await cursor.execute(
                            "INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)",
                            (user_data.user_id, server_id)
                        )

                # Commit transaction
                await cursor.execute("COMMIT")
                self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
                    # Update servers relationship
                    await cursor.execute(
                        "DELETE FROM user_servers WHERE user_id = %s",
                        (user_data.user_id,)
                    )

                    if user_data.servers:
                        for server_id in user_data.servers:
                            await cursor.execute(
                                "INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)",
                                (user_data.user_id, server_id)
                            )

                    # Commit transaction
                    await cursor.execute("COMMIT")
                    self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")

            except MySQLError as e:
                self.logger.error(f"Error saving user: {e}")
                await cursor.execute("ROLLBACK")
                raise
                except MySQLError as e:
                    self.logger.error(f"Error saving user: {e}")
                    await cursor.execute("ROLLBACK")
                    raise

    async def save_server(self, server_id: int, server_name: str):
        """Save server information."""
        async with self._lock:
            async with self.pool.cursor() as cursor:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        INSERT INTO servers (server_id, server_name)
                        VALUES (%s, %s)

@@ -471,8 +541,8 @@ class MariaDBDatabase:

        if not server_ids:
            return {}

        async with self._lock:
            async with self.pool.cursor() as cursor:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    placeholders = ','.join(['%s'] * len(server_ids))
                    await cursor.execute(f"""
                        SELECT server_id, server_name

@@ -485,8 +555,8 @@ class MariaDBDatabase:

    async def add_server_to_user(self, user_id: int, server_id: int):
        """Add a server to user's server list."""
        async with self._lock:
            async with self.pool.cursor() as cursor:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("""
                        INSERT IGNORE INTO user_servers (user_id, server_id)
                        VALUES (%s, %s)

@@ -494,42 +564,175 @@ class MariaDBDatabase:

    async def get_all_users(self) -> List[UserData]:
        """Get all users from the database."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("""
                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                FROM users u
                LEFT JOIN user_servers us ON u.user_id = us.user_id
                GROUP BY u.user_id
            """)
            results = await cursor.fetchall()
            return [self._parse_user_result(r) for r in results]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    GROUP BY u.user_id
                """)
                results = await cursor.fetchall()
                return [self._parse_user_result(r) for r in results]

    async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
        """Get users with pagination."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    GROUP BY u.user_id
                    ORDER BY u.user_id
                    LIMIT %s OFFSET %s
                """, (limit, offset))
                results = await cursor.fetchall()
                return [self._parse_user_result(r) for r in results]

    async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]:
        """Search users with pagination."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                # Try full-text search first
                try:
                    await cursor.execute("""
                        SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                        FROM users u
                        LEFT JOIN user_servers us ON u.user_id = us.user_id
                        WHERE MATCH(u.username, u.display_name, u.bio) AGAINST(%s IN BOOLEAN MODE)
                        GROUP BY u.user_id
                        ORDER BY u.user_id
                        LIMIT %s OFFSET %s
                    """, (f"*{query}*", limit, offset))
                    results = await cursor.fetchall()
                    if results:
                        return [self._parse_user_result(r) for r in results]
                except MySQLError:
                    # Fall back to LIKE search if full-text fails
                    pass

                # Fallback to LIKE search
                search_pattern = f"%{query}%"
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    LEFT JOIN user_servers us ON u.user_id = us.user_id
                    WHERE u.username LIKE %s OR u.display_name LIKE %s OR u.user_id = %s
                    GROUP BY u.user_id
                    ORDER BY u.user_id
                    LIMIT %s OFFSET %s
                """, (search_pattern, search_pattern, query if query.isdigit() else None, limit, offset))
                results = await cursor.fetchall()
                return [self._parse_user_result(r) for r in results]

    async def get_user_count_total(self) -> int:
        """Get total number of users."""
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT COUNT(*) as count FROM users")
                result = await cursor.fetchone()
                return result['count'] if result else 0

    async def save_users_batch(self, users: List[UserData]):
        """Save multiple users in a batch operation."""
        if not users:
            return

        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                try:
                    await cursor.execute("START TRANSACTION")

                    # Prepare batch data for users
                    user_data = []
                    server_data = []

                    for user in users:
                        user_data.append((
                            user.user_id,
                            user.username,
                            user.discriminator,
                            user.display_name,
                            user.avatar_url,
                            user.banner_url,
                            user.bio,
                            user.status,
                            user.activity
                        ))

                        # Collect server relationships
                        for server_id in user.servers:
                            server_data.append((user.user_id, server_id))

                    # Batch insert users
                    await cursor.executemany("""
                        INSERT INTO users (user_id, username, discriminator, display_name,
                                           avatar_url, banner_url, bio, status, activity)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE
                            username = VALUES(username),
                            discriminator = VALUES(discriminator),
                            display_name = VALUES(display_name),
                            avatar_url = VALUES(avatar_url),
                            banner_url = VALUES(banner_url),
                            bio = VALUES(bio),
                            status = VALUES(status),
                            activity = VALUES(activity)
                    """, user_data)

                    # Clear existing server relationships for these users
                    user_ids = [user.user_id for user in users]
                    if user_ids:
                        placeholders = ','.join(['%s'] * len(user_ids))
                        await cursor.execute(f"""
                            DELETE FROM user_servers WHERE user_id IN ({placeholders})
                        """, user_ids)

                    # Batch insert server relationships
                    if server_data:
                        await cursor.executemany("""
                            INSERT IGNORE INTO user_servers (user_id, server_id)
                            VALUES (%s, %s)
                        """, server_data)

                    await cursor.execute("COMMIT")
                    self.logger.info(f"Batch saved {len(users)} users")

                except MySQLError as e:
                    await cursor.execute("ROLLBACK")
                    self.logger.error(f"Error in batch save: {e}")
                    raise

    async def get_users_by_server(self, server_id: int) -> List[UserData]:
        """Get all users that are members of a specific server."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("""
                SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                FROM users u
                JOIN user_servers us ON u.user_id = us.user_id
                WHERE us.server_id = %s
                GROUP BY u.user_id
            """, (server_id,))
            results = await cursor.fetchall()
            return [self._parse_user_result(r) for r in results]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
                    FROM users u
                    JOIN user_servers us ON u.user_id = us.user_id
                    WHERE us.server_id = %s
                    GROUP BY u.user_id
                """, (server_id,))
                results = await cursor.fetchall()
                return [self._parse_user_result(r) for r in results]

    async def get_user_count(self) -> int:
        """Get total number of users in database."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("SELECT COUNT(*) as user_count FROM users")
            result = await cursor.fetchone()
            return result['user_count'] if result else 0
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT COUNT(*) as user_count FROM users")
                result = await cursor.fetchone()
                return result['user_count'] if result else 0

    async def get_server_count(self) -> int:
        """Get total number of unique servers."""
        async with self.pool.cursor() as cursor:
            await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers")
            result = await cursor.fetchone()
            return result['server_count'] if result else 0
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers")
                result = await cursor.fetchone()
                return result['server_count'] if result else 0

    async def get_statistics(self) -> Dict[str, Any]:
        """Get database statistics using optimized queries."""

@@ -538,30 +741,48 @@ class MariaDBDatabase:

            'total_servers': await self.get_server_count(),
        }

        async with self.pool.cursor() as cursor:
            await cursor.execute("""
                SELECT server_id, COUNT(user_id) as user_count
                FROM user_servers
                GROUP BY server_id
                ORDER BY user_count DESC
                LIMIT 10
            """)
            most_active = await cursor.fetchall()
            # Convert to list of tuples for consistency with JSON version
            stats['most_active_servers'] = [(row['server_id'], row['user_count']) for row in most_active]

            # Get database size
            await cursor.execute("""
                SELECT
                    ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
            """)
            size_result = await cursor.fetchone()
            stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024)  # Convert to bytes
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    SELECT server_id, COUNT(user_id) as user_count
                    FROM user_servers
                    GROUP BY server_id
                    ORDER BY user_count DESC
                    LIMIT 10
                """)
                most_active = await cursor.fetchall()
                # Convert to list of tuples for consistency with JSON version
                stats['most_active_servers'] = [(row['server_id'], row['user_count']) for row in most_active]

                # Get database size
                await cursor.execute("""
                    SELECT
                        ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb
                    FROM information_schema.tables
                    WHERE table_schema = DATABASE()
                """)
                size_result = await cursor.fetchone()
                stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024)  # Convert to bytes

        return stats

    async def export_to_csv(self, output_path: str):
        """Export data to CSV format."""
        import csv
        users = await self.get_all_users()

        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['user_id', 'username', 'discriminator', 'display_name',
                          'avatar_url', 'banner_url', 'bio', 'status', 'activity',
                          'servers', 'created_at', 'updated_at']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

            writer.writeheader()
            for user in users:
                row = user.to_dict()
                row['servers'] = ','.join(map(str, user.servers))
                writer.writerow(row)

    async def cleanup_old_backups(self, max_backups: int = 5):
        """Clean up old backup documents (placeholder for MariaDB)."""
        # For MariaDB, we could implement this as cleaning up old backup tables

src/rate_limiter.py

@@ -9,6 +9,72 @@ from typing import Optional
import logging


class AdaptiveRateLimiter:
    """Adaptive rate limiter that adjusts to Discord API rate limits."""

    def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.current_delay = base_delay
        self.consecutive_429s = 0
        self.last_request_time: Optional[float] = None
        self.success_count = 0
        self.logger = logging.getLogger(__name__)

    async def wait(self):
        """Wait with adaptive delay based on rate limit responses."""
        current_time = time.time()

        # Check if we need to wait based on current delay
        if self.last_request_time is not None:
            time_since_last_request = current_time - self.last_request_time
            if time_since_last_request < self.current_delay:
                wait_time = self.current_delay - time_since_last_request
                self.logger.debug(f"Adaptive rate limit wait: {wait_time:.2f}s")
                await asyncio.sleep(wait_time)

        self.last_request_time = time.time()

    def on_success(self):
        """Called when a request succeeds."""
        self.success_count += 1

        # After 10 successful requests, try to reduce delay
        if self.success_count >= 10:
            old_delay = self.current_delay
            self.current_delay = max(self.base_delay, self.current_delay * 0.9)
            self.consecutive_429s = 0
            self.success_count = 0
            if old_delay != self.current_delay:
                self.logger.debug(f"Reduced delay to {self.current_delay:.2f}s")

    def on_rate_limit(self, retry_after: Optional[float] = None):
        """Called when a 429 rate limit is encountered."""
        self.consecutive_429s += 1
        self.success_count = 0

        if retry_after:
            # Use Discord's suggested retry time
            self.current_delay = min(self.max_delay, retry_after)
            self.logger.warning(f"Rate limited, using Discord's retry_after: {retry_after}s")
        else:
            # Exponential backoff
            old_delay = self.current_delay
            self.current_delay = min(self.max_delay, self.base_delay * (2 ** self.consecutive_429s))
            self.logger.warning(f"Rate limited, increased delay from {old_delay:.2f}s to {self.current_delay:.2f}s")

    def get_stats(self) -> dict:
        """Get rate limiter statistics."""
        current_time = time.time()
        return {
            'current_delay': self.current_delay,
            'base_delay': self.base_delay,
            'consecutive_429s': self.consecutive_429s,
            'success_count': self.success_count,
            'time_since_last_request': current_time - self.last_request_time if self.last_request_time else 0
        }


class RateLimiter:
    """Rate limiter to prevent hitting Discord API limits."""