From 22f59ea2ad0669d85bace8c17a7e1e9a99001297 Mon Sep 17 00:00:00 2001 From: Xargana Date: Mon, 14 Jul 2025 12:59:57 +0300 Subject: [PATCH] fixed stuff --- src/client.py | 25 ++++++++++++++++------ src/database.py | 55 +++++++++++++++++++++++++------------------------ 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/client.py b/src/client.py index 82d6029..2fb8025 100644 --- a/src/client.py +++ b/src/client.py @@ -155,16 +155,26 @@ class DiscordDataClient(discord.Client): try: # Get all members - discord.py-self API members = [] - member_iterator = await guild.fetch_members() - async for member in member_iterator: - members.append(member) + member_result = await guild.fetch_members() + + # Handle different return types from fetch_members + if hasattr(member_result, '__aiter__'): + # It's an async iterator + async for member in member_result: + members.append(member) + elif isinstance(member_result, list): + # It's already a list + members = member_result + else: + # Try to iterate over it + for member in member_result: + members.append(member) for member in members: if not member.bot: - await self._process_user(member, guild.id) - - # Rate limiting + # Rate limiting before processing await self.rate_limiter.wait() + await self._process_user(member, guild.id) self.logger.info(f"Processed {len(members)} members from {guild.name}") @@ -224,6 +234,9 @@ class DiscordDataClient(discord.Client): async def _get_mutual_guilds(self, user) -> List[int]: """Get mutual guilds for a user using fetch_user_profile.""" try: + # Rate limiting before API call + await self.rate_limiter.wait() + # Use fetch_user_profile to get mutual guilds profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True) diff --git a/src/database.py b/src/database.py index 632e171..745887a 100644 --- a/src/database.py +++ b/src/database.py @@ -7,7 +7,7 @@ import json import shutil from datetime import datetime from typing import Dict, List, Optional, Any -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, field from pathlib import Path import logging @@ -35,13 +35,11 @@ class UserData: bio: Optional[str] = None status: Optional[str] = None activity: Optional[str] = None - servers: List[int] = None + servers: List[int] = field(default_factory=list) created_at: datetime = None updated_at: datetime = None def __post_init__(self): - if self.servers is None: - self.servers = [] current_time = datetime.utcnow() if self.created_at is None: self.created_at = current_time @@ -458,38 +456,41 @@ class MariaDBDatabase: async def save_server(self, server_id: int, server_name: str): """Save server information.""" - async with self.pool.cursor() as cursor: - await cursor.execute(""" - INSERT INTO servers (server_id, server_name) - VALUES (%s, %s) - ON DUPLICATE KEY UPDATE - server_name = VALUES(server_name), - updated_at = CURRENT_TIMESTAMP - """, (server_id, server_name)) + async with self._lock: + async with self.pool.cursor() as cursor: + await cursor.execute(""" + INSERT INTO servers (server_id, server_name) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE + server_name = VALUES(server_name), + updated_at = CURRENT_TIMESTAMP + """, (server_id, server_name)) async def get_server_names(self, server_ids: List[int]) -> Dict[int, str]: """Get server names for given server IDs.""" if not server_ids: return {} - async with self.pool.cursor() as cursor: - placeholders = ','.join(['%s'] * len(server_ids)) - await cursor.execute(f""" - SELECT server_id, server_name - FROM servers - WHERE server_id IN ({placeholders}) - """, server_ids) - - result = await cursor.fetchall() - return {row['server_id']: row['server_name'] for row in result} + async with self._lock: + async with self.pool.cursor() as cursor: + placeholders = ','.join(['%s'] * len(server_ids)) + await cursor.execute(f""" + SELECT server_id, server_name + FROM servers + WHERE server_id IN ({placeholders}) + """, server_ids) + + result = await cursor.fetchall() + return {row['server_id']: row['server_name'] for row in result} async def add_server_to_user(self, user_id: int, server_id: int): """Add a server to user's server list.""" - async with self.pool.cursor() as cursor: - await cursor.execute(""" - INSERT IGNORE INTO user_servers (user_id, server_id) - VALUES (%s, %s) - """, (user_id, server_id)) + async with self._lock: + async with self.pool.cursor() as cursor: + await cursor.execute(""" + INSERT IGNORE INTO user_servers (user_id, server_id) + VALUES (%s, %s) + """, (user_id, server_id)) async def get_all_users(self) -> List[UserData]: """Get all users from the database."""