why is there optimization in my racism app (d*scord)
This commit is contained in:
parent
10f66b95fd
commit
3131b0c839
139
DATABASE_OPTIMIZATION.md
Normal file
139
DATABASE_OPTIMIZATION.md
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
# Database Optimization for 50k+ Users
|
||||||
|
|
||||||
|
## 1. Database Schema Improvements
|
||||||
|
|
||||||
|
### Add Indexes
|
||||||
|
```sql
|
||||||
|
-- Primary performance indexes
|
||||||
|
CREATE INDEX idx_users_username ON users(username);
|
||||||
|
CREATE INDEX idx_users_display_name ON users(display_name);
|
||||||
|
CREATE INDEX idx_users_updated_at ON users(updated_at);
|
||||||
|
CREATE INDEX idx_user_servers_server_id ON user_servers(server_id);
|
||||||
|
|
||||||
|
-- Composite indexes for common queries
|
||||||
|
CREATE INDEX idx_users_username_display ON users(username, display_name);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query Optimization
|
||||||
|
```sql
|
||||||
|
-- Instead of GROUP_CONCAT, use separate queries
|
||||||
|
SELECT * FROM users WHERE user_id = ?;
|
||||||
|
SELECT server_id FROM user_servers WHERE user_id = ?;
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. Connection Pool Implementation
|
||||||
|
|
||||||
|
Replace single connection with proper pooling:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import aiomysql
|
||||||
|
|
||||||
|
async def create_pool():
|
||||||
|
return await aiomysql.create_pool(
|
||||||
|
host='localhost',
|
||||||
|
port=3306,
|
||||||
|
user='user',
|
||||||
|
password='password',
|
||||||
|
db='database',
|
||||||
|
minsize=5,
|
||||||
|
maxsize=20,
|
||||||
|
charset='utf8mb4'
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. Pagination Implementation
|
||||||
|
|
||||||
|
### Database Layer
|
||||||
|
```python
|
||||||
|
async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
await cursor.execute("""
|
||||||
|
SELECT * FROM users
|
||||||
|
ORDER BY user_id
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
""", (limit, offset))
|
||||||
|
```
|
||||||
|
|
||||||
|
### CLI Layer
|
||||||
|
```python
|
||||||
|
async def search_user_paginated(query: str, page: int = 1, per_page: int = 50):
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
users = await database.search_users(query, offset, per_page)
|
||||||
|
# Display with pagination controls
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. Search Optimization
|
||||||
|
|
||||||
|
### Full-Text Search
|
||||||
|
```sql
|
||||||
|
-- Add full-text index for better search
|
||||||
|
ALTER TABLE users ADD FULLTEXT(username, display_name, bio);
|
||||||
|
|
||||||
|
-- Use full-text search instead of LIKE
|
||||||
|
SELECT * FROM users
|
||||||
|
WHERE MATCH(username, display_name) AGAINST(? IN BOOLEAN MODE);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cached Search Results
|
||||||
|
```python
|
||||||
|
# Cache frequent searches
|
||||||
|
from functools import lru_cache
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1000)
|
||||||
|
async def cached_user_search(query: str):
|
||||||
|
return await database.search_users(query)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. Batch Operations
|
||||||
|
|
||||||
|
### Bulk Inserts
|
||||||
|
```python
|
||||||
|
async def save_users_batch(self, users: List[UserData]):
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
# Use executemany for bulk operations
|
||||||
|
await cursor.executemany("""
|
||||||
|
INSERT INTO users (...) VALUES (...)
|
||||||
|
ON DUPLICATE KEY UPDATE ...
|
||||||
|
""", [(user.user_id, user.username, ...) for user in users])
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. Rate Limiting Improvements
|
||||||
|
|
||||||
|
### Smarter Rate Limiting
|
||||||
|
```python
|
||||||
|
class AdaptiveRateLimiter:
|
||||||
|
def __init__(self):
|
||||||
|
self.base_delay = 1.0
|
||||||
|
self.consecutive_429s = 0
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
if self.consecutive_429s > 0:
|
||||||
|
delay = self.base_delay * (2 ** self.consecutive_429s)
|
||||||
|
await asyncio.sleep(min(delay, 60)) # Cap at 60 seconds
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(self.base_delay)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Estimates (50k users):
|
||||||
|
|
||||||
|
### Current Implementation:
|
||||||
|
- **get_all_users()**: ~30-60 seconds + 2-4GB RAM
|
||||||
|
- **CLI search**: ~10-30 seconds per search
|
||||||
|
- **Database saves**: ~5-10x slower due to locking
|
||||||
|
|
||||||
|
### Optimized Implementation:
|
||||||
|
- **Paginated queries**: ~0.1-0.5 seconds per page
|
||||||
|
- **Indexed search**: ~0.1-1 second per search
|
||||||
|
- **Connection pool**: ~2-3x faster concurrent operations
|
||||||
|
- **Memory usage**: ~50-100MB instead of GB
|
||||||
|
|
||||||
|
## Implementation Priority:
|
||||||
|
|
||||||
|
1. **High Priority**: Connection pooling, pagination, indexes
|
||||||
|
2. **Medium Priority**: Search optimization, batch operations
|
||||||
|
3. **Low Priority**: Caching, adaptive rate limiting
|
||||||
|
|
||||||
|
The optimized version should handle 50k users with reasonable performance (~10x slower than 500 users instead of 100x slower).
|
32
cli.py
32
cli.py
|
@ -52,27 +52,22 @@ async def show_stats():
|
||||||
print(f" Server {server_id}: {user_count} users")
|
print(f" Server {server_id}: {user_count} users")
|
||||||
|
|
||||||
|
|
||||||
async def search_user(query: str):
|
async def search_user(query: str, page: int = 1, per_page: int = 10):
|
||||||
"""Search for users."""
|
"""Search for users with pagination."""
|
||||||
config = Config()
|
config = Config()
|
||||||
database = await create_database(mariadb_config=config.get_mariadb_config())
|
database = await create_database(mariadb_config=config.get_mariadb_config())
|
||||||
|
|
||||||
all_users = await database.get_all_users()
|
offset = (page - 1) * per_page
|
||||||
|
results = await database.search_users(query, offset, per_page)
|
||||||
# Search by username or user ID
|
|
||||||
results = []
|
|
||||||
for user in all_users:
|
|
||||||
if (query.lower() in user.username.lower() or
|
|
||||||
query.lower() in (user.display_name or "").lower() or
|
|
||||||
query == str(user.user_id)):
|
|
||||||
results.append(user)
|
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
print("No users found matching the query.")
|
print("No users found matching the query.")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"\n=== Found {len(results)} users ===")
|
total_count = await database.get_user_count_total()
|
||||||
for user in results[:10]: # Show first 10 results
|
total_pages = (total_count + per_page - 1) // per_page
|
||||||
|
print(f"\n=== Search Results (Page {page} of {total_pages}) ===")
|
||||||
|
for user in results:
|
||||||
print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
|
print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
|
||||||
if user.display_name:
|
if user.display_name:
|
||||||
print(f" Display name: {user.display_name}")
|
print(f" Display name: {user.display_name}")
|
||||||
|
@ -98,6 +93,13 @@ async def search_user(query: str):
|
||||||
|
|
||||||
print(f" Last updated: {user.updated_at}")
|
print(f" Last updated: {user.updated_at}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
# Add pagination navigation info
|
||||||
|
print(f"\nShowing {len(results)} of {total_count} total users")
|
||||||
|
if total_pages > 1:
|
||||||
|
print(f"Use --page {page + 1} for next page" if page < total_pages else "Last page")
|
||||||
|
|
||||||
|
await database.close()
|
||||||
|
|
||||||
|
|
||||||
async def list_servers():
|
async def list_servers():
|
||||||
|
@ -292,6 +294,8 @@ def main():
|
||||||
# Search command
|
# Search command
|
||||||
search_parser = subparsers.add_parser("search", help="Search for users")
|
search_parser = subparsers.add_parser("search", help="Search for users")
|
||||||
search_parser.add_argument("query", help="Search query (username or user ID)")
|
search_parser.add_argument("query", help="Search query (username or user ID)")
|
||||||
|
search_parser.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
|
||||||
|
search_parser.add_argument("--per-page", type=int, default=10, help="Results per page (default: 10)")
|
||||||
|
|
||||||
# Server commands
|
# Server commands
|
||||||
servers_parser = subparsers.add_parser("servers", help="List all servers with user counts")
|
servers_parser = subparsers.add_parser("servers", help="List all servers with user counts")
|
||||||
|
@ -323,7 +327,7 @@ def main():
|
||||||
elif args.command == "stats":
|
elif args.command == "stats":
|
||||||
asyncio.run(show_stats())
|
asyncio.run(show_stats())
|
||||||
elif args.command == "search":
|
elif args.command == "search":
|
||||||
asyncio.run(search_user(args.query))
|
asyncio.run(search_user(args.query, args.page, getattr(args, 'per_page', 10)))
|
||||||
elif args.command == "servers":
|
elif args.command == "servers":
|
||||||
asyncio.run(list_servers())
|
asyncio.run(list_servers())
|
||||||
elif args.command == "user-servers":
|
elif args.command == "user-servers":
|
||||||
|
|
|
@ -15,7 +15,7 @@ except ImportError:
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .database import UserData
|
from .database import UserData
|
||||||
from .rate_limiter import RateLimiter
|
from .rate_limiter import RateLimiter, AdaptiveRateLimiter
|
||||||
|
|
||||||
|
|
||||||
class DiscordDataClient(discord.Client):
|
class DiscordDataClient(discord.Client):
|
||||||
|
@ -26,9 +26,9 @@ class DiscordDataClient(discord.Client):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.database = database
|
self.database = database
|
||||||
self.rate_limiter = RateLimiter(
|
self.rate_limiter = AdaptiveRateLimiter(
|
||||||
requests_per_minute=config.max_requests_per_minute,
|
base_delay=config.request_delay,
|
||||||
delay_between_requests=config.request_delay
|
max_delay=60.0
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
@ -251,7 +251,16 @@ class DiscordDataClient(discord.Client):
|
||||||
await self.rate_limiter.wait()
|
await self.rate_limiter.wait()
|
||||||
|
|
||||||
# Use fetch_user_profile to get mutual guilds
|
# Use fetch_user_profile to get mutual guilds
|
||||||
profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True)
|
try:
|
||||||
|
profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True)
|
||||||
|
self.rate_limiter.on_success()
|
||||||
|
except discord.HTTPException as e:
|
||||||
|
if e.status == 429: # Rate limited
|
||||||
|
retry_after = getattr(e, 'retry_after', None)
|
||||||
|
self.rate_limiter.on_rate_limit(retry_after)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
if hasattr(profile, 'mutual_guilds') and profile.mutual_guilds:
|
if hasattr(profile, 'mutual_guilds') and profile.mutual_guilds:
|
||||||
mutual_guild_ids = []
|
mutual_guild_ids = []
|
||||||
|
|
585
src/database.py
585
src/database.py
|
@ -15,7 +15,7 @@ import logging
|
||||||
|
|
||||||
# Optional MariaDB support
|
# Optional MariaDB support
|
||||||
try:
|
try:
|
||||||
from asyncmy import connect
|
from asyncmy import connect, create_pool
|
||||||
from asyncmy.cursors import DictCursor
|
from asyncmy.cursors import DictCursor
|
||||||
from asyncmy.errors import MySQLError
|
from asyncmy.errors import MySQLError
|
||||||
MARIADB_AVAILABLE = True
|
MARIADB_AVAILABLE = True
|
||||||
|
@ -255,6 +255,44 @@ class JSONDatabase:
|
||||||
result[server_id] = servers[str(server_id)]
|
result[server_id] = servers[str(server_id)]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
|
||||||
|
"""Get users with pagination (JSON implementation)."""
|
||||||
|
async with self._lock:
|
||||||
|
data = self._load_data()
|
||||||
|
users = []
|
||||||
|
for user_id, user_data in data.items():
|
||||||
|
if user_id != "servers": # Skip servers data
|
||||||
|
users.append(UserData.from_dict(user_data))
|
||||||
|
|
||||||
|
# Sort by user_id and paginate
|
||||||
|
users.sort(key=lambda u: u.user_id)
|
||||||
|
return users[offset:offset + limit]
|
||||||
|
|
||||||
|
async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]:
|
||||||
|
"""Search users with pagination (JSON implementation)."""
|
||||||
|
async with self._lock:
|
||||||
|
data = self._load_data()
|
||||||
|
matching_users = []
|
||||||
|
|
||||||
|
for user_id, user_data in data.items():
|
||||||
|
if user_id != "servers": # Skip servers data
|
||||||
|
user = UserData.from_dict(user_data)
|
||||||
|
if (query.lower() in user.username.lower() or
|
||||||
|
query.lower() in (user.display_name or "").lower() or
|
||||||
|
query == str(user.user_id)):
|
||||||
|
matching_users.append(user)
|
||||||
|
|
||||||
|
# Sort by user_id and paginate
|
||||||
|
matching_users.sort(key=lambda u: u.user_id)
|
||||||
|
return matching_users[offset:offset + limit]
|
||||||
|
|
||||||
|
async def get_user_count_total(self) -> int:
|
||||||
|
"""Get total number of users (JSON implementation)."""
|
||||||
|
async with self._lock:
|
||||||
|
data = self._load_data()
|
||||||
|
# Count all keys except "servers"
|
||||||
|
return len([k for k in data.keys() if k != "servers"])
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close database (no-op for JSON database)."""
|
"""Close database (no-op for JSON database)."""
|
||||||
self.logger.info("JSON database closed")
|
self.logger.info("JSON database closed")
|
||||||
|
@ -284,101 +322,132 @@ class MariaDBDatabase:
|
||||||
}
|
}
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.pool = None
|
self.pool = None
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
# Table schema versions
|
# Table schema versions
|
||||||
self.schema_version = 1
|
self.schema_version = 1
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize database connection and ensure tables exist."""
|
"""Initialize database connection pool and ensure tables exist."""
|
||||||
try:
|
try:
|
||||||
# Add DictCursor to config for dictionary results
|
# Create connection pool instead of single connection
|
||||||
self.db_config['cursor_cls'] = DictCursor
|
self.pool = await create_pool(
|
||||||
self.pool = await connect(**self.db_config)
|
host=self.db_config['host'],
|
||||||
|
port=self.db_config['port'],
|
||||||
|
user=self.db_config['user'],
|
||||||
|
password=self.db_config['password'],
|
||||||
|
db=self.db_config['db'],
|
||||||
|
minsize=5,
|
||||||
|
maxsize=20,
|
||||||
|
charset='utf8mb4',
|
||||||
|
cursorclass=DictCursor
|
||||||
|
)
|
||||||
await self._create_tables()
|
await self._create_tables()
|
||||||
self.logger.info("Database connection established")
|
await self._create_indexes()
|
||||||
|
self.logger.info("Database connection pool established")
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
self.logger.error(f"Database connection failed: {e}")
|
self.logger.error(f"Database connection failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _create_tables(self):
|
async def _create_tables(self):
|
||||||
"""Create necessary tables if they don't exist."""
|
"""Create necessary tables if they don't exist."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("""
|
async with conn.cursor() as cursor:
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
await cursor.execute("""
|
||||||
user_id BIGINT PRIMARY KEY,
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
username VARCHAR(32) NOT NULL,
|
user_id BIGINT PRIMARY KEY,
|
||||||
discriminator VARCHAR(4) NOT NULL,
|
username VARCHAR(32) NOT NULL,
|
||||||
display_name VARCHAR(32),
|
discriminator VARCHAR(4) NOT NULL,
|
||||||
avatar_url VARCHAR(255),
|
display_name VARCHAR(32),
|
||||||
banner_url VARCHAR(255),
|
avatar_url VARCHAR(255),
|
||||||
bio TEXT,
|
banner_url VARCHAR(255),
|
||||||
status VARCHAR(20),
|
bio TEXT,
|
||||||
activity VARCHAR(50),
|
status VARCHAR(20),
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
activity VARCHAR(50),
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
)
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||||
""")
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS servers (
|
||||||
|
server_id BIGINT PRIMARY KEY,
|
||||||
|
server_name VARCHAR(100) NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS user_servers (
|
||||||
|
user_id BIGINT,
|
||||||
|
server_id BIGINT,
|
||||||
|
PRIMARY KEY (user_id, server_id),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
await cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
version INT PRIMARY KEY,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
await cursor.execute("""
|
# Check schema version
|
||||||
CREATE TABLE IF NOT EXISTS servers (
|
await cursor.execute("""
|
||||||
server_id BIGINT PRIMARY KEY,
|
INSERT IGNORE INTO schema_version (version) VALUES (%s)
|
||||||
server_name VARCHAR(100) NOT NULL,
|
""", (self.schema_version,))
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
async def _create_indexes(self):
|
||||||
)
|
"""Create performance indexes."""
|
||||||
""")
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute("""
|
# Create indexes for better performance
|
||||||
CREATE TABLE IF NOT EXISTS user_servers (
|
indexes = [
|
||||||
user_id BIGINT,
|
"CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)",
|
||||||
server_id BIGINT,
|
"CREATE INDEX IF NOT EXISTS idx_users_display_name ON users(display_name)",
|
||||||
PRIMARY KEY (user_id, server_id),
|
"CREATE INDEX IF NOT EXISTS idx_users_updated_at ON users(updated_at)",
|
||||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE,
|
"CREATE INDEX IF NOT EXISTS idx_user_servers_server_id ON user_servers(server_id)",
|
||||||
FOREIGN KEY (server_id) REFERENCES servers(server_id) ON DELETE CASCADE
|
"CREATE INDEX IF NOT EXISTS idx_users_username_display ON users(username, display_name)",
|
||||||
)
|
"CREATE INDEX IF NOT EXISTS idx_servers_name ON servers(server_name)",
|
||||||
""")
|
]
|
||||||
|
|
||||||
await cursor.execute("""
|
for index_sql in indexes:
|
||||||
CREATE TABLE IF NOT EXISTS schema_version (
|
try:
|
||||||
version INT PRIMARY KEY,
|
await cursor.execute(index_sql)
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
self.logger.debug(f"Created index: {index_sql}")
|
||||||
)
|
except MySQLError as e:
|
||||||
""")
|
if "Duplicate key name" not in str(e):
|
||||||
|
self.logger.warning(f"Failed to create index: {e}")
|
||||||
# Check schema version
|
|
||||||
await cursor.execute("""
|
# Add full-text search index
|
||||||
INSERT IGNORE INTO schema_version (version) VALUES (%s)
|
try:
|
||||||
""", (self.schema_version,))
|
await cursor.execute("""
|
||||||
|
ALTER TABLE users ADD FULLTEXT(username, display_name, bio)
|
||||||
|
""")
|
||||||
|
except MySQLError as e:
|
||||||
|
if "Duplicate key name" not in str(e):
|
||||||
|
self.logger.debug(f"Full-text index already exists or failed: {e}")
|
||||||
|
|
||||||
self.logger.info("Database tables verified/created")
|
self.logger.info("Database tables verified/created")
|
||||||
|
|
||||||
async def _execute_query(self, query: str, params: tuple = None):
|
|
||||||
"""Execute a database query with error handling."""
|
|
||||||
async with self._lock:
|
|
||||||
try:
|
|
||||||
async with self.pool.cursor() as cursor:
|
|
||||||
await cursor.execute(query, params)
|
|
||||||
return cursor
|
|
||||||
except MySQLError as e:
|
|
||||||
self.logger.error(f"Database error: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user(self, user_id: int) -> Optional[UserData]:
|
async def get_user(self, user_id: int) -> Optional[UserData]:
|
||||||
"""Get user data by ID with associated servers."""
|
"""Get user data by ID with associated servers."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("""
|
async with conn.cursor() as cursor:
|
||||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
await cursor.execute("""
|
||||||
FROM users u
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
FROM users u
|
||||||
WHERE u.user_id = %s
|
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||||
GROUP BY u.user_id
|
WHERE u.user_id = %s
|
||||||
""", (user_id,))
|
GROUP BY u.user_id
|
||||||
|
""", (user_id,))
|
||||||
result = await cursor.fetchone()
|
|
||||||
if result:
|
result = await cursor.fetchone()
|
||||||
return self._parse_user_result(result)
|
if result:
|
||||||
return None
|
return self._parse_user_result(result)
|
||||||
|
return None
|
||||||
|
|
||||||
def _parse_user_result(self, result: Dict) -> UserData:
|
def _parse_user_result(self, result: Dict) -> UserData:
|
||||||
"""Convert database result to UserData object."""
|
"""Convert database result to UserData object."""
|
||||||
|
@ -400,64 +469,65 @@ class MariaDBDatabase:
|
||||||
|
|
||||||
async def save_user(self, user_data: UserData):
|
async def save_user(self, user_data: UserData):
|
||||||
"""Save or update user data with transaction."""
|
"""Save or update user data with transaction."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
try:
|
async with conn.cursor() as cursor:
|
||||||
# Start transaction
|
try:
|
||||||
await cursor.execute("START TRANSACTION")
|
# Start transaction
|
||||||
|
await cursor.execute("START TRANSACTION")
|
||||||
# Upsert user data
|
|
||||||
await cursor.execute("""
|
# Upsert user data
|
||||||
INSERT INTO users (
|
await cursor.execute("""
|
||||||
user_id, username, discriminator, display_name,
|
INSERT INTO users (
|
||||||
avatar_url, banner_url, bio, status, activity
|
user_id, username, discriminator, display_name,
|
||||||
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
avatar_url, banner_url, bio, status, activity
|
||||||
ON DUPLICATE KEY UPDATE
|
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
username = VALUES(username),
|
ON DUPLICATE KEY UPDATE
|
||||||
discriminator = VALUES(discriminator),
|
username = VALUES(username),
|
||||||
display_name = VALUES(display_name),
|
discriminator = VALUES(discriminator),
|
||||||
avatar_url = VALUES(avatar_url),
|
display_name = VALUES(display_name),
|
||||||
banner_url = VALUES(banner_url),
|
avatar_url = VALUES(avatar_url),
|
||||||
bio = VALUES(bio),
|
banner_url = VALUES(banner_url),
|
||||||
status = VALUES(status),
|
bio = VALUES(bio),
|
||||||
activity = VALUES(activity)
|
status = VALUES(status),
|
||||||
""", (
|
activity = VALUES(activity)
|
||||||
user_data.user_id,
|
""", (
|
||||||
user_data.username,
|
user_data.user_id,
|
||||||
user_data.discriminator,
|
user_data.username,
|
||||||
user_data.display_name,
|
user_data.discriminator,
|
||||||
user_data.avatar_url,
|
user_data.display_name,
|
||||||
user_data.banner_url,
|
user_data.avatar_url,
|
||||||
user_data.bio,
|
user_data.banner_url,
|
||||||
user_data.status,
|
user_data.bio,
|
||||||
user_data.activity
|
user_data.status,
|
||||||
))
|
user_data.activity
|
||||||
|
))
|
||||||
|
|
||||||
# Update servers relationship
|
# Update servers relationship
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"DELETE FROM user_servers WHERE user_id = %s",
|
"DELETE FROM user_servers WHERE user_id = %s",
|
||||||
(user_data.user_id,)
|
(user_data.user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_data.servers:
|
if user_data.servers:
|
||||||
for server_id in user_data.servers:
|
for server_id in user_data.servers:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)",
|
"INSERT IGNORE INTO user_servers (user_id, server_id) VALUES (%s, %s)",
|
||||||
(user_data.user_id, server_id)
|
(user_data.user_id, server_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Commit transaction
|
# Commit transaction
|
||||||
await cursor.execute("COMMIT")
|
await cursor.execute("COMMIT")
|
||||||
self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
|
self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator}")
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
self.logger.error(f"Error saving user: {e}")
|
self.logger.error(f"Error saving user: {e}")
|
||||||
await cursor.execute("ROLLBACK")
|
await cursor.execute("ROLLBACK")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def save_server(self, server_id: int, server_name: str):
|
async def save_server(self, server_id: int, server_name: str):
|
||||||
"""Save server information."""
|
"""Save server information."""
|
||||||
async with self._lock:
|
async with self.pool.acquire() as conn:
|
||||||
async with self.pool.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute("""
|
await cursor.execute("""
|
||||||
INSERT INTO servers (server_id, server_name)
|
INSERT INTO servers (server_id, server_name)
|
||||||
VALUES (%s, %s)
|
VALUES (%s, %s)
|
||||||
|
@ -471,8 +541,8 @@ class MariaDBDatabase:
|
||||||
if not server_ids:
|
if not server_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async with self._lock:
|
async with self.pool.acquire() as conn:
|
||||||
async with self.pool.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
placeholders = ','.join(['%s'] * len(server_ids))
|
placeholders = ','.join(['%s'] * len(server_ids))
|
||||||
await cursor.execute(f"""
|
await cursor.execute(f"""
|
||||||
SELECT server_id, server_name
|
SELECT server_id, server_name
|
||||||
|
@ -485,8 +555,8 @@ class MariaDBDatabase:
|
||||||
|
|
||||||
async def add_server_to_user(self, user_id: int, server_id: int):
|
async def add_server_to_user(self, user_id: int, server_id: int):
|
||||||
"""Add a server to user's server list."""
|
"""Add a server to user's server list."""
|
||||||
async with self._lock:
|
async with self.pool.acquire() as conn:
|
||||||
async with self.pool.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute("""
|
await cursor.execute("""
|
||||||
INSERT IGNORE INTO user_servers (user_id, server_id)
|
INSERT IGNORE INTO user_servers (user_id, server_id)
|
||||||
VALUES (%s, %s)
|
VALUES (%s, %s)
|
||||||
|
@ -494,42 +564,175 @@ class MariaDBDatabase:
|
||||||
|
|
||||||
async def get_all_users(self) -> List[UserData]:
|
async def get_all_users(self) -> List[UserData]:
|
||||||
"""Get all users from the database."""
|
"""Get all users from the database."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("""
|
async with conn.cursor() as cursor:
|
||||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
await cursor.execute("""
|
||||||
FROM users u
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
FROM users u
|
||||||
GROUP BY u.user_id
|
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||||
""")
|
GROUP BY u.user_id
|
||||||
results = await cursor.fetchall()
|
""")
|
||||||
return [self._parse_user_result(r) for r in results]
|
results = await cursor.fetchall()
|
||||||
|
return [self._parse_user_result(r) for r in results]
|
||||||
|
|
||||||
|
async def get_users_paginated(self, offset: int = 0, limit: int = 100) -> List[UserData]:
|
||||||
|
"""Get users with pagination."""
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
await cursor.execute("""
|
||||||
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
|
FROM users u
|
||||||
|
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||||
|
GROUP BY u.user_id
|
||||||
|
ORDER BY u.user_id
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
""", (limit, offset))
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
return [self._parse_user_result(r) for r in results]
|
||||||
|
|
||||||
|
async def search_users(self, query: str, offset: int = 0, limit: int = 100) -> List[UserData]:
|
||||||
|
"""Search users with pagination."""
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
# Try full-text search first
|
||||||
|
try:
|
||||||
|
await cursor.execute("""
|
||||||
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
|
FROM users u
|
||||||
|
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||||
|
WHERE MATCH(u.username, u.display_name, u.bio) AGAINST(%s IN BOOLEAN MODE)
|
||||||
|
GROUP BY u.user_id
|
||||||
|
ORDER BY u.user_id
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
""", (f"*{query}*", limit, offset))
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
if results:
|
||||||
|
return [self._parse_user_result(r) for r in results]
|
||||||
|
except MySQLError:
|
||||||
|
# Fall back to LIKE search if full-text fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback to LIKE search
|
||||||
|
search_pattern = f"%{query}%"
|
||||||
|
await cursor.execute("""
|
||||||
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
|
FROM users u
|
||||||
|
LEFT JOIN user_servers us ON u.user_id = us.user_id
|
||||||
|
WHERE u.username LIKE %s OR u.display_name LIKE %s OR u.user_id = %s
|
||||||
|
GROUP BY u.user_id
|
||||||
|
ORDER BY u.user_id
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
""", (search_pattern, search_pattern, query if query.isdigit() else None, limit, offset))
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
return [self._parse_user_result(r) for r in results]
|
||||||
|
|
||||||
|
async def get_user_count_total(self) -> int:
|
||||||
|
"""Get total number of users."""
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
await cursor.execute("SELECT COUNT(*) as count FROM users")
|
||||||
|
result = await cursor.fetchone()
|
||||||
|
return result['count'] if result else 0
|
||||||
|
|
||||||
|
async def save_users_batch(self, users: List[UserData]):
|
||||||
|
"""Save multiple users in a batch operation."""
|
||||||
|
if not users:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self.pool.acquire() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
try:
|
||||||
|
await cursor.execute("START TRANSACTION")
|
||||||
|
|
||||||
|
# Prepare batch data for users
|
||||||
|
user_data = []
|
||||||
|
server_data = []
|
||||||
|
|
||||||
|
for user in users:
|
||||||
|
user_data.append((
|
||||||
|
user.user_id,
|
||||||
|
user.username,
|
||||||
|
user.discriminator,
|
||||||
|
user.display_name,
|
||||||
|
user.avatar_url,
|
||||||
|
user.banner_url,
|
||||||
|
user.bio,
|
||||||
|
user.status,
|
||||||
|
user.activity
|
||||||
|
))
|
||||||
|
|
||||||
|
# Collect server relationships
|
||||||
|
for server_id in user.servers:
|
||||||
|
server_data.append((user.user_id, server_id))
|
||||||
|
|
||||||
|
# Batch insert users
|
||||||
|
await cursor.executemany("""
|
||||||
|
INSERT INTO users (user_id, username, discriminator, display_name,
|
||||||
|
avatar_url, banner_url, bio, status, activity)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
ON DUPLICATE KEY UPDATE
|
||||||
|
username = VALUES(username),
|
||||||
|
discriminator = VALUES(discriminator),
|
||||||
|
display_name = VALUES(display_name),
|
||||||
|
avatar_url = VALUES(avatar_url),
|
||||||
|
banner_url = VALUES(banner_url),
|
||||||
|
bio = VALUES(bio),
|
||||||
|
status = VALUES(status),
|
||||||
|
activity = VALUES(activity)
|
||||||
|
""", user_data)
|
||||||
|
|
||||||
|
# Clear existing server relationships for these users
|
||||||
|
user_ids = [user.user_id for user in users]
|
||||||
|
if user_ids:
|
||||||
|
placeholders = ','.join(['%s'] * len(user_ids))
|
||||||
|
await cursor.execute(f"""
|
||||||
|
DELETE FROM user_servers WHERE user_id IN ({placeholders})
|
||||||
|
""", user_ids)
|
||||||
|
|
||||||
|
# Batch insert server relationships
|
||||||
|
if server_data:
|
||||||
|
await cursor.executemany("""
|
||||||
|
INSERT IGNORE INTO user_servers (user_id, server_id)
|
||||||
|
VALUES (%s, %s)
|
||||||
|
""", server_data)
|
||||||
|
|
||||||
|
await cursor.execute("COMMIT")
|
||||||
|
self.logger.info(f"Batch saved {len(users)} users")
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
await cursor.execute("ROLLBACK")
|
||||||
|
self.logger.error(f"Error in batch save: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def get_users_by_server(self, server_id: int) -> List[UserData]:
|
async def get_users_by_server(self, server_id: int) -> List[UserData]:
|
||||||
"""Get all users that are members of a specific server."""
|
"""Get all users that are members of a specific server."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("""
|
async with conn.cursor() as cursor:
|
||||||
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
await cursor.execute("""
|
||||||
FROM users u
|
SELECT u.*, GROUP_CONCAT(us.server_id) AS servers
|
||||||
JOIN user_servers us ON u.user_id = us.user_id
|
FROM users u
|
||||||
WHERE us.server_id = %s
|
JOIN user_servers us ON u.user_id = us.user_id
|
||||||
GROUP BY u.user_id
|
WHERE us.server_id = %s
|
||||||
""", (server_id,))
|
GROUP BY u.user_id
|
||||||
results = await cursor.fetchall()
|
""", (server_id,))
|
||||||
return [self._parse_user_result(r) for r in results]
|
results = await cursor.fetchall()
|
||||||
|
return [self._parse_user_result(r) for r in results]
|
||||||
|
|
||||||
async def get_user_count(self) -> int:
|
async def get_user_count(self) -> int:
|
||||||
"""Get total number of users in database."""
|
"""Get total number of users in database."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("SELECT COUNT(*) as user_count FROM users")
|
async with conn.cursor() as cursor:
|
||||||
result = await cursor.fetchone()
|
await cursor.execute("SELECT COUNT(*) as user_count FROM users")
|
||||||
return result['user_count'] if result else 0
|
result = await cursor.fetchone()
|
||||||
|
return result['user_count'] if result else 0
|
||||||
|
|
||||||
async def get_server_count(self) -> int:
|
async def get_server_count(self) -> int:
|
||||||
"""Get total number of unique servers."""
|
"""Get total number of unique servers."""
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers")
|
async with conn.cursor() as cursor:
|
||||||
result = await cursor.fetchone()
|
await cursor.execute("SELECT COUNT(DISTINCT server_id) as server_count FROM user_servers")
|
||||||
return result['server_count'] if result else 0
|
result = await cursor.fetchone()
|
||||||
|
return result['server_count'] if result else 0
|
||||||
|
|
||||||
async def get_statistics(self) -> Dict[str, Any]:
|
async def get_statistics(self) -> Dict[str, Any]:
|
||||||
"""Get database statistics using optimized queries."""
|
"""Get database statistics using optimized queries."""
|
||||||
|
@ -538,30 +741,48 @@ class MariaDBDatabase:
|
||||||
'total_servers': await self.get_server_count(),
|
'total_servers': await self.get_server_count(),
|
||||||
}
|
}
|
||||||
|
|
||||||
async with self.pool.cursor() as cursor:
|
async with self.pool.acquire() as conn:
|
||||||
await cursor.execute("""
|
async with conn.cursor() as cursor:
|
||||||
SELECT server_id, COUNT(user_id) as user_count
|
await cursor.execute("""
|
||||||
FROM user_servers
|
SELECT server_id, COUNT(user_id) as user_count
|
||||||
GROUP BY server_id
|
FROM user_servers
|
||||||
ORDER BY user_count DESC
|
GROUP BY server_id
|
||||||
LIMIT 10
|
ORDER BY user_count DESC
|
||||||
""")
|
LIMIT 10
|
||||||
most_active = await cursor.fetchall()
|
""")
|
||||||
# Convert to list of tuples for consistency with JSON version
|
most_active = await cursor.fetchall()
|
||||||
stats['most_active_servers'] = [(row['server_id'], row['user_count']) for row in most_active]
|
# Convert to list of tuples for consistency with JSON version
|
||||||
|
stats['most_active_servers'] = [(row['server_id'], row['user_count']) for row in most_active]
|
||||||
# Get database size
|
|
||||||
await cursor.execute("""
|
# Get database size
|
||||||
SELECT
|
await cursor.execute("""
|
||||||
ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb
|
SELECT
|
||||||
FROM information_schema.tables
|
ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS database_size_mb
|
||||||
WHERE table_schema = DATABASE()
|
FROM information_schema.tables
|
||||||
""")
|
WHERE table_schema = DATABASE()
|
||||||
size_result = await cursor.fetchone()
|
""")
|
||||||
stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024) # Convert to bytes
|
size_result = await cursor.fetchone()
|
||||||
|
stats['database_size'] = int((size_result['database_size_mb'] or 0) * 1024 * 1024) # Convert to bytes
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
async def export_to_csv(self, output_path: str):
|
||||||
|
"""Export data to CSV format."""
|
||||||
|
import csv
|
||||||
|
users = await self.get_all_users()
|
||||||
|
|
||||||
|
with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||||
|
fieldnames = ['user_id', 'username', 'discriminator', 'display_name',
|
||||||
|
'avatar_url', 'banner_url', 'bio', 'status', 'activity',
|
||||||
|
'servers', 'created_at', 'updated_at']
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
|
||||||
|
writer.writeheader()
|
||||||
|
for user in users:
|
||||||
|
row = user.to_dict()
|
||||||
|
row['servers'] = ','.join(map(str, user.servers))
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
async def cleanup_old_backups(self, max_backups: int = 5):
|
async def cleanup_old_backups(self, max_backups: int = 5):
|
||||||
"""Clean up old backup documents (placeholder for MariaDB)."""
|
"""Clean up old backup documents (placeholder for MariaDB)."""
|
||||||
# For MariaDB, we could implement this as cleaning up old backup tables
|
# For MariaDB, we could implement this as cleaning up old backup tables
|
||||||
|
|
|
@ -9,6 +9,72 @@ from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveRateLimiter:
|
||||||
|
"""Adaptive rate limiter that adjusts to Discord API rate limits."""
|
||||||
|
|
||||||
|
def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0):
|
||||||
|
self.base_delay = base_delay
|
||||||
|
self.max_delay = max_delay
|
||||||
|
self.current_delay = base_delay
|
||||||
|
self.consecutive_429s = 0
|
||||||
|
self.last_request_time: Optional[float] = None
|
||||||
|
self.success_count = 0
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
"""Wait with adaptive delay based on rate limit responses."""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check if we need to wait based on current delay
|
||||||
|
if self.last_request_time is not None:
|
||||||
|
time_since_last_request = current_time - self.last_request_time
|
||||||
|
if time_since_last_request < self.current_delay:
|
||||||
|
wait_time = self.current_delay - time_since_last_request
|
||||||
|
self.logger.debug(f"Adaptive rate limit wait: {wait_time:.2f}s")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
|
self.last_request_time = time.time()
|
||||||
|
|
||||||
|
def on_success(self):
|
||||||
|
"""Called when a request succeeds."""
|
||||||
|
self.success_count += 1
|
||||||
|
|
||||||
|
# After 10 successful requests, try to reduce delay
|
||||||
|
if self.success_count >= 10:
|
||||||
|
old_delay = self.current_delay
|
||||||
|
self.current_delay = max(self.base_delay, self.current_delay * 0.9)
|
||||||
|
self.consecutive_429s = 0
|
||||||
|
self.success_count = 0
|
||||||
|
if old_delay != self.current_delay:
|
||||||
|
self.logger.debug(f"Reduced delay to {self.current_delay:.2f}s")
|
||||||
|
|
||||||
|
def on_rate_limit(self, retry_after: Optional[float] = None):
|
||||||
|
"""Called when a 429 rate limit is encountered."""
|
||||||
|
self.consecutive_429s += 1
|
||||||
|
self.success_count = 0
|
||||||
|
|
||||||
|
if retry_after:
|
||||||
|
# Use Discord's suggested retry time
|
||||||
|
self.current_delay = min(self.max_delay, retry_after)
|
||||||
|
self.logger.warning(f"Rate limited, using Discord's retry_after: {retry_after}s")
|
||||||
|
else:
|
||||||
|
# Exponential backoff
|
||||||
|
old_delay = self.current_delay
|
||||||
|
self.current_delay = min(self.max_delay, self.base_delay * (2 ** self.consecutive_429s))
|
||||||
|
self.logger.warning(f"Rate limited, increased delay from {old_delay:.2f}s to {self.current_delay:.2f}s")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""Get rate limiter statistics."""
|
||||||
|
current_time = time.time()
|
||||||
|
return {
|
||||||
|
'current_delay': self.current_delay,
|
||||||
|
'base_delay': self.base_delay,
|
||||||
|
'consecutive_429s': self.consecutive_429s,
|
||||||
|
'success_count': self.success_count,
|
||||||
|
'time_since_last_request': current_time - self.last_request_time if self.last_request_time else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
"""Rate limiter to prevent hitting Discord API limits."""
|
"""Rate limiter to prevent hitting Discord API limits."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue