discordb/cli.py

346 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Command-line interface for Discord Data Collector.
"""
import argparse
import asyncio
import json
import sys
from pathlib import Path
# Add src to path
sys.path.append(str(Path(__file__).parent))
from src.config import Config
from src.database import create_database
from src.client import DiscordDataClient
async def export_data(format_type: str, output_path: str = None):
    """Export collected data to a file.

    Args:
        format_type: Export format. Only ``"csv"`` is handled; anything else
            prints an "Unsupported format" message and returns.
        output_path: Destination file path. When ``None``, a timestamped
            path of the form ``data/export_<YYYYmmdd_HHMMSS>.<format>`` is
            generated.
    """
    # Reject unknown formats before loading configuration or connecting, so
    # a bad argument never opens a database connection at all.
    if format_type != "csv":
        print(f"Unsupported format: {format_type}")
        return

    if output_path is None:
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = f"data/export_{timestamp}.{format_type}"

    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        await database.export_to_csv(output_path)
        print(f"Data exported to {output_path}")
    finally:
        # Release the connection even when the export itself fails.
        await database.close()
async def show_stats():
    """Print database statistics.

    Reads the mapping returned by ``database.get_statistics()`` and prints
    the user total, server total, database size and, when present, the first
    five entries of ``most_active_servers``.
    """
    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        stats = await database.get_statistics()
    finally:
        # The connection is only needed for the single statistics query.
        await database.close()

    print("\n=== Database Statistics ===")
    print(f"Total users: {stats['total_users']}")
    print(f"Total servers: {stats['total_servers']}")
    print(f"Database size: {stats['database_size']} bytes")

    most_active = stats['most_active_servers']
    if most_active:
        print("\nMost active servers:")
        # Each entry unpacks as a (server_id, user_count) pair.
        for server_id, user_count in most_active[:5]:
            print(f" Server {server_id}: {user_count} users")
async def search_user(query: str, page: int = 1, per_page: int = 10):
    """Search for users and print one page of results.

    Args:
        query: Search text passed to ``database.search_users``.
        page: 1-based page number.
        per_page: Number of results per page.

    Non-positive ``page`` or ``per_page`` values are rejected with a message
    (they would otherwise produce a negative offset or a division by zero
    when computing the page count).
    """
    if page < 1 or per_page < 1:
        print("Page and per-page values must be positive integers.")
        return

    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        offset = (page - 1) * per_page
        results = await database.search_users(query, offset, per_page)
        if not results:
            print("No users found matching the query.")
            return

        # NOTE(review): the name get_user_count_total suggests this counts
        # every stored user rather than only matches for `query`; if so the
        # page total below overstates the result set — confirm.
        total_count = await database.get_user_count_total()
        total_pages = (total_count + per_page - 1) // per_page  # ceiling division

        print(f"\n=== Search Results (Page {page} of {total_pages}) ===")
        for user in results:
            print(f"{user.username}#{user.discriminator} (ID: {user.user_id})")
            if user.display_name:
                print(f" Display name: {user.display_name}")
            if user.bio:
                print(f" Bio: {user.bio[:100]}{'...' if len(user.bio) > 100 else ''}")
            if user.status:
                print(f" Status: {user.status}")
            if user.activity:
                print(f" Activity: {user.activity[:50]}{'...' if len(user.activity) > 50 else ''}")

            # Show at most five server names, falling back to the raw ID
            # when no name is known.
            if user.servers:
                server_names = await database.get_server_names(user.servers)
                server_display = [
                    server_names.get(server_id, f"Unknown ({server_id})")
                    for server_id in user.servers[:5]
                ]
                more_text = "..." if len(user.servers) > 5 else ""
                print(f" Servers ({len(user.servers)}): {', '.join(server_display)}{more_text}")
            else:
                print(" Servers: None")
            print(f" Last updated: {user.updated_at}")
            print()

        # Pagination navigation hint.
        print(f"\nShowing {len(results)} of {total_count} total users")
        if total_pages > 1:
            print(f"Use --page {page + 1} for next page" if page < total_pages else "Last page")
    finally:
        # Runs on every exit path, including the early "no results" return.
        await database.close()
async def list_servers():
    """List all servers with user counts.

    Counts, across every stored user, how many users list each server ID,
    then prints the servers ordered by user count (highest first) using the
    names returned by ``database.get_server_names``.
    """
    from collections import Counter

    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        users = await database.get_all_users()

        # Tally users per server; a user with no server list contributes nothing.
        server_counts = Counter(
            server_id for user in users for server_id in (user.servers or ())
        )
        if not server_counts:
            print("No servers found in database.")
            return

        # most_common() orders by count descending, ties in first-seen order.
        sorted_servers = server_counts.most_common()
        server_ids = [server_id for server_id, _ in sorted_servers]
        server_names = await database.get_server_names(server_ids)

        print(f"\n=== Servers ({len(sorted_servers)} total) ===")
        for server_id, user_count in sorted_servers:
            server_name = server_names.get(server_id, f"Unknown Server ({server_id})")
            print(f"{server_name}: {user_count} users")
    finally:
        # Runs on every exit path, including the early "no servers" return.
        await database.close()
async def show_user_servers(query: str):
    """Print a user's profile summary and the servers they belong to.

    ``query`` is first tried as a numeric user ID. If it is not numeric, the
    first stored user whose username or display name contains ``query``
    (case-insensitively) is used instead.
    """
    cfg = Config()
    db = await create_database(mariadb_config=cfg.get_mariadb_config())
    try:
        found = None
        try:
            # Numeric input: direct lookup by ID.
            found = await db.get_user(int(query))
        except ValueError:
            # Non-numeric input: substring match over all stored users.
            needle = query.lower()
            candidates = await db.get_all_users()
            found = next(
                (
                    person
                    for person in candidates
                    if needle in person.username.lower()
                    or needle in (person.display_name or "").lower()
                ),
                None,
            )

        if not found:
            print(f"User '{query}' not found in database.")
            return

        print(f"\n=== User: {found.username}#{found.discriminator} ===")
        print(f"User ID: {found.user_id}")
        if found.display_name:
            print(f"Display Name: {found.display_name}")
        if found.bio:
            ellipsis = '...' if len(found.bio) > 100 else ''
            print(f"Bio: {found.bio[:100]}{ellipsis}")
        if found.status:
            print(f"Status: {found.status}")
        if found.activity:
            print(f"Activity: {found.activity}")

        if not found.servers:
            print("\nServers: None")
            return

        # Resolve IDs to names, falling back to the raw ID when unknown.
        names_by_id = await db.get_server_names(found.servers)
        print(f"\nServers ({len(found.servers)}):")
        for sid in found.servers:
            label = names_by_id.get(sid, f"Unknown Server ({sid})")
            print(f" - {label}")
    finally:
        await db.close()
async def show_server_users(server_id: str):
    """Show users in a server.

    Args:
        server_id: Server ID as text; it must parse as an integer.

    Prints each user recorded for the server, sorted case-insensitively by
    username, with status/activity, display name and bio when present.
    """
    # Validate the ID before connecting. Keeping the try this narrow also
    # means a ValueError raised inside the database layer is no longer
    # misreported as a malformed server ID.
    try:
        server_id_int = int(server_id)
    except (TypeError, ValueError):
        print("Invalid server ID. Please provide a numeric server ID.")
        return

    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        users = await database.get_users_by_server(server_id_int)
        if not users:
            print(f"No users found for server {server_id}.")
            return

        print(f"\n=== Server {server_id} ({len(users)} users) ===")
        for user in sorted(users, key=lambda u: u.username.lower()):
            status_info = ""
            if user.status:
                status_info = f" [{user.status}]"
            if user.activity:
                status_info += f" ({user.activity[:30]}{'...' if len(user.activity) > 30 else ''})"
            print(f"{user.username}#{user.discriminator} (ID: {user.user_id}){status_info}")
            # Skip the display name when it merely repeats the username.
            if user.display_name and user.display_name != user.username:
                print(f" Display: {user.display_name}")
            if user.bio:
                print(f" Bio: {user.bio[:80]}{'...' if len(user.bio) > 80 else ''}")
            print()
    finally:
        await database.close()
async def backup_database():
    """Create a manual backup of the database.

    Dumps every user (via ``user.to_dict()``) as a JSON array to
    ``data/backups/manual_backup_<YYYYmmdd_HHMMSS>.json``, creating the
    backup directory if needed.
    """
    from datetime import datetime

    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = f"data/backups/manual_backup_{timestamp}.json"

        # Export all data to JSON for backup.
        users = await database.get_all_users()
        backup_data = [user.to_dict() for user in users]

        # Ensure backup directory exists.
        Path("data/backups").mkdir(parents=True, exist_ok=True)
        with open(backup_path, 'w', encoding='utf-8') as f:
            json.dump(backup_data, f, indent=2, ensure_ascii=False)
        print(f"Database backed up to {backup_path}")
    finally:
        # Release the connection even if fetching, serialising or writing fails.
        await database.close()
async def cleanup_data():
    """Clean up old data and backups.

    Delegates to ``database.cleanup_old_backups`` with ``max_backups=5``.
    """
    config = Config()
    database = await create_database(mariadb_config=config.get_mariadb_config())
    try:
        await database.cleanup_old_backups(max_backups=5)
        print("Cleanup completed")
    finally:
        # Release the connection even if the cleanup raises.
        await database.close()
async def test_connection():
    """Test the Discord connection by logging in with the configured token.

    Prints the logged-in account on success. On any failure prints the error
    and exits the process with status 1.
    """
    database = None
    client = None
    try:
        config = Config()
        database = await create_database(mariadb_config=config.get_mariadb_config())
        client = DiscordDataClient(config, database)
        print("Testing Discord connection...")
        # This will test the connection without starting the full bot
        await client.login(config.discord_token)
        user_info = client.user
        print(f"✓ Successfully connected as {user_info.name}#{user_info.discriminator}")
        print(f"✓ User ID: {user_info.id}")
    except Exception as e:
        print(f"✗ Connection failed: {e}")
        sys.exit(1)
    finally:
        # Release whatever was acquired, on both the success and failure
        # paths (the finally clause still runs while SystemExit propagates).
        # NOTE(review): assumes closing the database after the client is
        # safe even if DiscordDataClient.close() also releases it — confirm.
        if client is not None:
            await client.close()
        if database is not None:
            await database.close()
def _positive_int(value):
    """argparse ``type`` callable: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def main(argv=None):
    """Main CLI entry point.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). Defaults to ``sys.argv[1:]`` when ``None``.

    Prints the help text and returns when no sub-command is given; otherwise
    runs the coroutine for the selected sub-command to completion.
    """
    parser = argparse.ArgumentParser(description="Discord Data Collector CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Export command
    export_parser = subparsers.add_parser("export", help="Export collected data")
    export_parser.add_argument("format", choices=["csv"], help="Export format")
    export_parser.add_argument("-o", "--output", help="Output file path")

    # Stats command
    subparsers.add_parser("stats", help="Show database statistics")

    # Search command. Page values must be >= 1: zero or negative values
    # would produce a negative offset or a division by zero downstream.
    search_parser = subparsers.add_parser("search", help="Search for users")
    search_parser.add_argument("query", help="Search query (username or user ID)")
    search_parser.add_argument("--page", type=_positive_int, default=1, help="Page number (default: 1)")
    search_parser.add_argument("--per-page", type=_positive_int, default=10, help="Results per page (default: 10)")

    # Server commands
    subparsers.add_parser("servers", help="List all servers with user counts")
    user_servers_parser = subparsers.add_parser("user-servers", help="Show servers a user is in")
    user_servers_parser.add_argument("query", help="User ID or username to lookup")
    server_users_parser = subparsers.add_parser("server-users", help="Show users in a server")
    server_users_parser.add_argument("server_id", help="Server ID to lookup")

    # Backup command
    subparsers.add_parser("backup", help="Create manual database backup")
    # Cleanup command
    subparsers.add_parser("cleanup", help="Clean up old data and backups")
    # Test command
    subparsers.add_parser("test", help="Test Discord connection")

    args = parser.parse_args(argv)
    if not args.command:
        parser.print_help()
        return

    # Map each sub-command to a factory for its coroutine. Lambdas defer
    # attribute access on `args` until the chosen command is known.
    handlers = {
        "export": lambda: export_data(args.format, args.output),
        "stats": lambda: show_stats(),
        "search": lambda: search_user(args.query, args.page, args.per_page),
        "servers": lambda: list_servers(),
        "user-servers": lambda: show_user_servers(args.query),
        "server-users": lambda: show_server_users(args.server_id),
        "backup": lambda: backup_database(),
        "cleanup": lambda: cleanup_data(),
        "test": lambda: test_connection(),
    }
    # argparse only accepts registered sub-command names, so the lookup
    # cannot miss.
    asyncio.run(handlers[args.command]())
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()