fixed stuff
This commit is contained in:
parent
db6b0a996b
commit
22f59ea2ad
|
@ -155,16 +155,26 @@ class DiscordDataClient(discord.Client):
|
|||
try:
|
||||
# Get all members - discord.py-self API
|
||||
members = []
|
||||
member_iterator = await guild.fetch_members()
|
||||
async for member in member_iterator:
|
||||
members.append(member)
|
||||
member_result = await guild.fetch_members()
|
||||
|
||||
# Handle different return types from fetch_members
|
||||
if hasattr(member_result, '__aiter__'):
|
||||
# It's an async iterator
|
||||
async for member in member_result:
|
||||
members.append(member)
|
||||
elif isinstance(member_result, list):
|
||||
# It's already a list
|
||||
members = member_result
|
||||
else:
|
||||
# Try to iterate over it
|
||||
for member in member_result:
|
||||
members.append(member)
|
||||
|
||||
for member in members:
|
||||
if not member.bot:
|
||||
await self._process_user(member, guild.id)
|
||||
|
||||
# Rate limiting
|
||||
# Rate limiting before processing
|
||||
await self.rate_limiter.wait()
|
||||
await self._process_user(member, guild.id)
|
||||
|
||||
self.logger.info(f"Processed {len(members)} members from {guild.name}")
|
||||
|
||||
|
@ -224,6 +234,9 @@ class DiscordDataClient(discord.Client):
|
|||
async def _get_mutual_guilds(self, user) -> List[int]:
|
||||
"""Get mutual guilds for a user using fetch_user_profile."""
|
||||
try:
|
||||
# Rate limiting before API call
|
||||
await self.rate_limiter.wait()
|
||||
|
||||
# Use fetch_user_profile to get mutual guilds
|
||||
profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import json
|
|||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
@ -35,13 +35,11 @@ class UserData:
|
|||
bio: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
activity: Optional[str] = None
|
||||
servers: List[int] = None
|
||||
servers: List[int] = field(default_factory=list)
|
||||
created_at: datetime = None
|
||||
updated_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.servers is None:
|
||||
self.servers = []
|
||||
current_time = datetime.utcnow()
|
||||
if self.created_at is None:
|
||||
self.created_at = current_time
|
||||
|
@ -458,38 +456,41 @@ class MariaDBDatabase:
|
|||
|
||||
async def save_server(self, server_id: int, server_name: str):
|
||||
"""Save server information."""
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT INTO servers (server_id, server_name)
|
||||
VALUES (%s, %s)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
server_name = VALUES(server_name),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", (server_id, server_name))
|
||||
async with self._lock:
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT INTO servers (server_id, server_name)
|
||||
VALUES (%s, %s)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
server_name = VALUES(server_name),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", (server_id, server_name))
|
||||
|
||||
async def get_server_names(self, server_ids: List[int]) -> Dict[int, str]:
|
||||
"""Get server names for given server IDs."""
|
||||
if not server_ids:
|
||||
return {}
|
||||
|
||||
async with self.pool.cursor() as cursor:
|
||||
placeholders = ','.join(['%s'] * len(server_ids))
|
||||
await cursor.execute(f"""
|
||||
SELECT server_id, server_name
|
||||
FROM servers
|
||||
WHERE server_id IN ({placeholders})
|
||||
""", server_ids)
|
||||
|
||||
result = await cursor.fetchall()
|
||||
return {row['server_id']: row['server_name'] for row in result}
|
||||
async with self._lock:
|
||||
async with self.pool.cursor() as cursor:
|
||||
placeholders = ','.join(['%s'] * len(server_ids))
|
||||
await cursor.execute(f"""
|
||||
SELECT server_id, server_name
|
||||
FROM servers
|
||||
WHERE server_id IN ({placeholders})
|
||||
""", server_ids)
|
||||
|
||||
result = await cursor.fetchall()
|
||||
return {row['server_id']: row['server_name'] for row in result}
|
||||
|
||||
async def add_server_to_user(self, user_id: int, server_id: int):
|
||||
"""Add a server to user's server list."""
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT IGNORE INTO user_servers (user_id, server_id)
|
||||
VALUES (%s, %s)
|
||||
""", (user_id, server_id))
|
||||
async with self._lock:
|
||||
async with self.pool.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT IGNORE INTO user_servers (user_id, server_id)
|
||||
VALUES (%s, %s)
|
||||
""", (user_id, server_id))
|
||||
|
||||
async def get_all_users(self) -> List[UserData]:
|
||||
"""Get all users from the database."""
|
||||
|
|
Loading…
Reference in a new issue