major update, switch over to asyncpg

parent 58adbca6
......@@ -31,7 +31,7 @@ import psutil
from discord.ext import commands
from sqlalchemy import __version__ as sqlalchemy_version
from .utils import sql
#from .utils import sql
from .utils.colors import paint
from .utils.logger import get_logger
......@@ -45,105 +45,6 @@ class Analytics:
# Logging
# Messages
async def log_message(self, message):
author = message.author
channel = message.channel
guild = channel.guild
with self.bot.db.session() as session:
msg_author = session.query(sql.User).filter_by(id=author.id).first()
if msg_author is None:
msg_author = sql.User(
id=author.id,
name=author.name,
bot=author.bot,
discrim=author.discriminator
)
session.add(msg_author)
if msg_author.name != author.name:
msg_author.name = author.name
if msg_author.discrim != author.discriminator:
msg_author.discrim = author.discriminator
new_message = sql.Message(
id=message.id,
timestamp=message.created_at,
author_id=author.id,
channel_id=channel.id,
guild_id=guild.id,
content=message.content
)
session.add(new_message)
# Updated messages
async def log_message_change(self, message, deleted=False):
author = message.author
channel = message.channel
guild = channel.guild
with self.bot.db.session() as session:
msg_author = session.query(sql.User).filter_by(id=author.id).first()
if msg_author is None:
msg_author = sql.User(
id=author.id,
name=author.name,
bot=author.bot,
discrim=author.discriminator
)
session.add(msg_author)
if msg_author.name != author.name:
msg_author.name = author.name
if msg_author.discrim != author.discriminator:
msg_author.discrim = author.discriminator
new_message = sql.MessageChange(
id=message.id,
timestamp=message.created_at,
author_id=author.id,
channel_id=channel.id,
guild_id=guild.id,
content=message.content,
deleted=deleted
)
session.add(new_message)
# Commands
async def log_command_use(self, ctx):
user = ctx.author
command_name = ctx.command.qualified_name
message = ctx.message
with self.bot.db.session() as session:
command_author = session.query(sql.User).filter_by(id=user.id).first()
if command_author is None:
command_author = sql.User(
id=user.id,
name=user.name,
bot=user.bot,
discrim=user.discriminator
)
session.add(command_author)
command = sql.Command(
message_id=message.id,
command_name=command_name,
user_id=user.id,
timestamp=ctx.message.created_at,
args=message.clean_content.split(ctx.invoked_with)[1].strip(),
errored=False
)
session.add(command)
# Socket data
def log_socket_data(self, data):
if "t" in data:
......@@ -164,47 +65,47 @@ class Analytics:
self.log_socket_data(payload)
# Command was triggered
async def on_command(self, ctx):
message = ctx.message
channel = ctx.channel
author = ctx.author
destination = None
# async def on_command(self, ctx):
# message = ctx.message
# channel = ctx.channel
# author = ctx.author
# destination = None
if not hasattr(channel, "guild"):
destination = "Private Message"
# if not hasattr(channel, "guild"):
# destination = "Private Message"
else:
destination = f"[{ctx.guild.name} #{channel.name}]"
# else:
# destination = f"[{ctx.guild.name} #{channel.name}]"
log.info(f"{destination}: {author.name}: {message.clean_content}")
# log.info(f"{destination}: {author.name}: {message.clean_content}")
await self.log_command_use(ctx)
# await self.log_command_use(ctx)
# Message arrived
async def on_message(self, message):
channel = message.channel
author = message.author
# async def on_message(self, message):
# channel = message.channel
# author = message.author
if author.bot or (not self.bot.is_ready()):
return
# if author.bot or (not self.bot.is_ready()):
# return
if hasattr(author, "display_name"):
await self.log_message(message)
# if hasattr(author, "display_name"):
# await self.log_message(message)
# Message deleted
async def on_message_delete(self, message):
channel = message.channel
author = message.author
if hasattr(channel, "guild") and hasattr(author, "display_name"):
await self.log_message_change(message, deleted=True)
# async def on_message_delete(self, message):
# channel = message.channel
# author = message.author
# if hasattr(channel, "guild") and hasattr(author, "display_name"):
# await self.log_message_change(message, deleted=True)
# Messaged edited
async def on_message_edit(self, old_message, new_message):
channel = new_message.channel
author = new_message.author
if old_message.content != new_message.content:
if hasattr(channel, "guild") and hasattr(author, "display_name"):
await self.log_message_change(new_message, deleted=False)
# async def on_message_edit(self, old_message, new_message):
# channel = new_message.channel
# author = new_message.author
# if old_message.content != new_message.content:
# if hasattr(channel, "guild") and hasattr(author, "display_name"):
# await self.log_message_change(new_message, deleted=False)
# Coommand tossed an error
async def on_command_error(self, ctx, error):
......@@ -238,41 +139,7 @@ class Analytics:
else:
print(f"{paint(type(error).__name__, 'b_red')}: {error}")
if ctx.command_failed:
with self.bot.db.session() as session:
command = session.query(sql.Command).filter_by(message_id=ctx.message.id).first()
if command is not None:
command.errored = True
# Commands
# Get process info
@commands.command(name="info", brief="bot info")
async def get_info(self, ctx):
last_commit = await self.bot.loop.run_in_executor(None, functools.partial(subprocess.run, "git log --pretty=format:\"%h by %an %ar (%s)\" -n 1", stdout=subprocess.PIPE, shell=True, universal_newlines=True))
process = psutil.Process()
output = discord.Embed(title="Information", color=0xFF8F00, description=f"Latest commit: **{last_commit.stdout}**\n\n[Gitlab Repo](https://gitlab.a-sketchy.site/AnonymousDapper/snake)")
output.add_field(name="Python", value=f"{platform.python_implementation()} {platform.python_version()}", inline=False)
output.add_field(name="Discord.py", value=discord.__version__, inline=False)
output.add_field(name="System", value=f"{platform.system()} {platform.machine()}", inline=False)
output.add_field(name="Kernel", value=platform.release(), inline=False)
with process.oneshot():
proc_mem_info = process.memory_full_info()
output.add_field(name="Used Memory (uss)", value=f"{int(proc_mem_info.uss / 1024 / 1024)}Mb", inline=False)
output.add_field(name="Used Memory (vms)", value=f"{int(proc_mem_info.vms / 1024 / 1024)}Mb", inline=False)
postgres_major, postgres_minor = self.bot.db.engine.dialect.server_version_info
output.add_field(name="PostgreSQL", value=f"{postgres_major}.{postgres_minor}", inline=False)
output.add_field(name="Database Driver", value=f"SQLAlchemy {sqlalchemy_version} with {self.bot.db.db_api}", inline=False)
await ctx.send(embed=output)
def setup(bot):
bot.add_cog(Analytics(bot))
......@@ -30,8 +30,9 @@ from functools import partial
from io import StringIO
from types import BuiltinFunctionType
import asyncpg
import discord
import sqlalchemy
from discord.ext import commands
......@@ -173,7 +174,7 @@ class Debug:
def get_info(self, result):
data = repr(result)
info = []
info = []
info.append(("Type", type(result).__name__))
info.append(("Memory", hex(id(result))))
......@@ -295,36 +296,25 @@ class Debug:
sql += ";"
try:
results = await self.bot.loop.run_in_executor(None, partial(self.bot.db.engine.execute, sql))
async with self.bot.db.pool.acquire() as conn:
results = await conn.fetch(sql)
except sqlalchemy.exc.ProgrammingError as e:
await ctx.send(f"```diff\n- {e.orig.message}\n```\n{e.orig.details.get('hint', 'Unknown fix')}\n\nDouble check your query:\n```sql\n{e.statement}\n{' ' * (int(e.orig.details.get('position', '0')) - 1)}^\n```")
except asyncpg.exceptions.PostgresSyntaxError as e:
await ctx.send(f"```diff\n- {e.message}\n```\n\nDouble check your query:\n```sql\n{e.query}\n```")
return
except Exception as e:
await ctx.send(f"```diff\n- {type(e).__name__}: {e}\n```")
return
if not results.returns_rows:
if len(results) == 0:
await self.bot.post_reaction(ctx.message, success=True)
else:
result_list = results.fetchall()
if len(result_list) == 0:
await ctx.send("\N{WARNING SIGN} Query returned 0 rows")
else:
row_names = results.keys()
# format a list of items
clr = lambda arr: ", ".join(str(item) for item in arr)
# f-string to format total result lsit
result = f"```md\n# Columns: {', '.join(row_names)}\n# {len(result_list)} total rows\n\n{self.NL.join('- ' + clr(arg) for arg in result_list)}\n```"
# I'm sorry
result = f"```md\n# Columns: {', '.join(results[0].keys())}\n# {len(results)} total rows\n\n{self.NL.join('- ' + (', '.join(str(item) for item in arg.values())) for arg in results)}\n```"
await ctx.send(await self.check_length(result))
await ctx.send(await self.check_length(result))
# Run shell commands
@commands.command(name="sh", brief="system terminal")
......
# MIT License
#
# Copyright (c) 2018 AnonymousDapper
#
......
......@@ -22,6 +22,7 @@
__all__ = ["get_logger", "set_level", "set_database"]
import asyncio
import inspect
import logging
import os
......@@ -29,8 +30,6 @@ import os
from datetime import datetime
from logging import handlers, Handler
from .sql import ErrorLog
# Make sure the log directory exists (and create it if not)
if not os.path.exists("logs"):
os.makedirs("logs")
......@@ -40,21 +39,12 @@ LOG_LEVEL = logging.INFO
class PostgresHandler(Handler):
def __init__(self, db):
self.db = db
self.loop = asyncio.get_event_loop()
super().__init__(logging.WARNING)
def emit(self, record):
with self.db.session() as session:
error_log = ErrorLog(
level=record.levelname,
module=record.module,
function=record.funcName,
filename=record.filename,
line_number=record.lineno,
message=record.msg,
timestamp=datetime.fromtimestamp(record.created)
)
session.add(error_log)
self.loop.create_task(self.db.create_error_report(record))
# Handlers
DATABASE_HANDLER = None # setup on init
......
......@@ -172,7 +172,7 @@ class MathParser:
elif isinstance(node, ast.BinOp):
op_name = type(node.op)
if op_name in self.operators:
if op_name in self.operators:
left_op = node.left
right_op = node.right
......
......@@ -20,208 +20,209 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
__all__ = ["Tag", "Permission", "User", "Blacklist", "Whitelist", "TagVariable", "Message", "MessageChange", "Command", "ErrorLog", "Prefix", "SQL"]
import asyncio
import traceback
from contextlib import contextmanager
from sqlalchemy import ForeignKey, Integer, BigInteger, String, DateTime, Boolean, Column, create_engine
from datetime import datetime
import asyncpg
from .logger import get_logger
log = get_logger()
USERS_TABLE = """
CREATE TABLE IF NOT EXISTS users (
id BIGINT NOT NULL,
name VARCHAR(40),
bot BOOLEAN,
discrim VARCHAR(4),
CONSTRAINT users_pk PRIMARY KEY (id),
UNIQUE (id)
);
"""
MESSAGES_TABLE = """
CREATE TABLE IF NOT EXISTS chat_logs (
id BIGINT NOT NULL,
timestamp TIMESTAMP,
author_id BIGINT,
channel_id BIGINT,
guild_id BIGINT,
content VARCHAR(2000),
CONSTRAINT chat_logs_pk PRIMARY KEY (id),
UNIQUE (id),
FOREIGN KEY(author_id) REFERENCES users (id)
);
"""
TAGS_TABLE = """
CREATE TABLE IF NOT EXISTS tags (
name VARCHAR(50) NOT NULL,
author_id BIGINT,
content VARCHAR(1950),
uses INTEGER,
timestamp TIMESTAMP,
data JSONB,
CONSTRAINT tags_pk PRIMARY KEY (name),
UNIQUE (name),
FOREIGN KEY(author_id) REFERENCES users (id)
);
"""
PREFIX_TABLE = """
CREATE TABLE IF NOT EXISTS prefixes (
id BIGINT NOT NULL,
personal BOOLEAN,
prefix VARCHAR(32),
CONSTRAINT prefixes_pk PRIMARY KEY (id),
UNIQUE (id)
);
"""
PERMISSION_TABLE = """
CREATE TABLE IF NOT EXISTS permissions (
pk SERIAL NOT NULL,
guild_id BIGINT,
channel_id BIGINT,
role_id BIGINT,
user_id BIGINT,
bits INTEGER,
CONSTRAINT permissions_pk PRIMARY KEY (pk),
UNIQUE (pk)
);
"""
BLACKLIST_TABLE = """
CREATE TABLE IF NOT EXISTS blacklist (
pk SERIAL NOT NULL,
guild_id BIGINT,
channel_id BIGINT,
role_id BIGINT,
user_id BIGINT,
data VARCHAR(2000),
CONSTRAINT blacklist_pk PRIMARY KEY (pk),
UNIQUE (pk)
);
"""
WHITELIST_TABLE = """
CREATE TABLE IF NOT EXISTS whitelist (
pk SERIAL NOT NULL,
guild_id BIGINT,
channel_id BIGINT,
role_id BIGINT,
user_id BIGINT,
data VARCHAR(2000),
CONSTRAINT whitelist_pk PRIMARY KEY (pk),
UNIQUE (pk)
);
"""
ERRORS_TABLE = """
CREATE TABLE IF NOT EXISTS logged_errors (
pk SERIAL NOT NULL,
level VARCHAR(30),
module VARCHAR(2000),
function VARCHAR(2000),
filename VARCHAR(2000),
line INTEGER,
message VARCHAR(2000),
timestamp TIMESTAMP,
CONSTRAINT logged_errors_pk PRIMARY KEY (pk),
UNIQUE (pk)
);
"""
STATS_TABLE = """
CREATE TABLE IF NOT EXISTS command_stats (
pk SERIAL NOT NULL,
message_id BIGINT,
command_name VARCHAR(40),
user_id BIGINT,
args VARCHAR(2000),
errored BOOLEAN,
CONSTRAINT command_stats_pk PRIMARY KEY (pk),
UNIQUE (pk),
FOREIGN KEY(message_id) REFERENCES chat_logs (id),
FOREIGN KEY(user_id) REFERENCES users (id)
);
"""
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.attributes import flag_modified
# Can't import logger here because logger:PostgresHandler refers to ErrorLog and causes import loop
Base = declarative_base()
# Class -> table mappings
class Tag(Base):
__tablename__ = "tags"
name = Column(String(50), primary_key=True, unique=True)
author_id = Column(BigInteger, ForeignKey("users.id"))
author = relationship("User", back_populates="tags")
content = Column(String(2000))
uses = Column(Integer)
timestamp = Column(DateTime)
def __repr__(self):
return f"<Tag(name='{self.name}', author_id={self.author_id}, uses={self.uses}, timestamp='{self.timestamp}')>"
class Permission(Base):
__tablename__ = "permissions"
pk = Column(Integer, primary_key=True)
guild_id = Column(BigInteger)
channel_id = Column(BigInteger)
role_id = Column(BigInteger)
user = relationship("User", back_populates="permissions")
user_id = Column(BigInteger, ForeignKey("users.id"))
bits = Column(BigInteger)
def __repr__(self):
return f"<Permission(guild_id={self.guild_id}, channel_id={self.channel_id}, user_id={self.user_id}, bits={self.bits})>"
class User(Base):
__tablename__ = "users"
id = Column(BigInteger, primary_key=True, unique=True)
name = Column(String(40))
bot = Column(Boolean)
discrim = Column(String(4))
permissions = relationship("Permission", back_populates="user", cascade="all, delete, delete-orphan")
messages = relationship("Message", back_populates="author", cascade="all, delete, delete-orphan")
tags = relationship("Tag", back_populates="author", cascade="all, delete, delete-orphan")
commands = relationship("Command", back_populates="user", cascade="all, delete, delete-orphan")
changed_messages = relationship("MessageChange", back_populates="author", cascade="all, delete, delete-orphan")
def __repr__(self):
return f"<User(id={self.id}, name='{self.name}', bot={self.bot}, discrim='{self.discrim}', permissions={self.permissions}, messages={self.messages}, tags={self.tags})>"
class Blacklist(Base):
__tablename__ = "blacklist"
pk = Column(Integer, primary_key=True)
guild_id = Column(BigInteger)
channel_id = Column(BigInteger)
role_id = Column(BigInteger)
user_id = Column(BigInteger)
data = Column(String)
def __repr__(self):
return f"<Blacklist(guild_id={self.guild_id}, channel_id={self.channel_id}, role_id={self.role_id}, user_id={self.user_id})>"
class Whitelist(Base):
__tablename__ = "whitelist"
pk = Column(Integer, primary_key=True)
guild_id = Column(BigInteger)
channel_id = Column(BigInteger)
role_id = Column(BigInteger)
user_id = Column(BigInteger)
data = Column(String)
def __repr__(self):
return f"<Whitelist(guild_id={self.guild_id}, channel_id={self.channel_id}, role_id={self.role_id}, user_id={self.user_id})>"
class TagVariable(Base):
__tablename__ = "tag_values"
tag_name = Column(String(50), primary_key=True, unique=True)
data = Column(JSONB) # JSONb as key:value pairs
def __repr__(self):
return f"<TagVariable(tag_name='{self.tag_name}, values={self.data})>"
class Message(Base):
__tablename__ = "chat_logs"
class SQL:
def __init__(self, *args, **kwargs):
self.loop = asyncio.get_event_loop()
id = Column(BigInteger, primary_key=True, unique=True)
timestamp = Column(DateTime)
author_id = Column(BigInteger, ForeignKey("users.id"))
author = relationship("User", back_populates="messages")
command = relationship("Command", back_populates="message")
channel_id = Column(BigInteger)
guild_id = Column(BigInteger)
content = Column(String(2000))
db_name = kwargs.get("db_name")
db_username = kwargs.get("db_username")