From 0867960b134680205946bdc05713d07f89f47785 Mon Sep 17 00:00:00 2001 From: Lars-Dominik Braun Date: Tue, 2 Oct 2018 19:23:09 +0200 Subject: irc: Refactoring/beautification Add logging, split bot into abstract bot implementation and actual chromebot implementation, move some reusable checks into decorators. --- crocoite/irc.py | 358 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 260 insertions(+), 98 deletions(-) (limited to 'crocoite/irc.py') diff --git a/crocoite/irc.py b/crocoite/irc.py index d2eda45..878bf5e 100644 --- a/crocoite/irc.py +++ b/crocoite/irc.py @@ -25,7 +25,10 @@ IRC bot “chromebot” import asyncio, argparse, uuid, json, tempfile from datetime import datetime from urllib.parse import urlsplit -from enum import IntEnum +from enum import IntEnum, Enum +from collections import defaultdict +from abc import abstractmethod +from functools import wraps import bottom ### helper functions ### @@ -59,17 +62,21 @@ def isValidUrl (s): class NonExitingArgumentParser (argparse.ArgumentParser): """ Argument parser that does not call exit(), suitable for interactive use """ + def exit (self, status=0, message=None): # should never be called pass def error (self, message): - raise Exception (message) + # if we use subparsers it’s important to return self, so we can show + # the correct help + raise Exception (self, message) def format_usage (self): return super().format_usage ().replace ('\n', ' ') class Status(IntEnum): + """ Job status """ undefined = 0 pending = 1 running = 2 @@ -77,6 +84,8 @@ class Status(IntEnum): finished = 4 class Job: + """ Archival job """ + __slots__ = ('id', 'stats', 'rstats', 'started', 'finished', 'nick', 'status', 'process', 'url') def __init__ (self, url, nick): @@ -105,114 +114,87 @@ class Job: stats.get ('failed', 0), prettyBytes (stats.get ('bytesRcv', 0))) -class Bot(bottom.Client): - __slots__ = ('jobs', 'channels', 'nick', 'tempdir', 'destdir', 'parser', 'processLimit') +class NickMode(Enum): + operator = '@' + voice = '+' - def __init__ (self, host, port, ssl, nick, channels=[], - tempdir=tempfile.gettempdir(), destdir='.', processLimit=1): - super().__init__ (host=host, port=port, ssl=ssl) - self.jobs = {} - self.channels = channels - self.nick = nick - self.tempdir = tempdir - self.destdir = destdir - self.processLimit = asyncio.Semaphore (processLimit) + @classmethod + def fromMode (cls, mode): + return {'v': cls.voice, 'o': cls.operator}[mode] - self.parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False) - subparsers = self.parser.add_subparsers(help='Sub-commands') +class User: + """ IRC user """ + __slots__ = ('name', 'modes') - archiveparser = subparsers.add_parser('a', help='Archive a site') - #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60]) - #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60]) - #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024]) - archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (9)) - archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0') - archiveparser.add_argument('url', help='Website URL', type=isValidUrl) - archiveparser.set_defaults (func=self.handleArchive) + def __init__ (self, name, modes=set ()): + self.name = name + self.modes = modes - statusparser = subparsers.add_parser ('s', help='Get job status') - statusparser.add_argument('id', help='Job id', metavar='UUID') - statusparser.set_defaults (func=self.handleStatus) + def __eq__ (self, b): + return self.name == b.name - abortparser = subparsers.add_parser ('r', help='Revoke/abort job') - abortparser.add_argument('id', help='Job id', metavar='UUID') - abortparser.set_defaults (func=self.handleAbort) - - # register bottom event handler - self.on('CLIENT_CONNECT', self.onConnect) - self.on('PING', self.onKeepalive) - self.on('PRIVMSG', self.onMessage) - self.on('CLIENT_DISCONNECT', self.onDisconnect) + def __hash__ (self): + return hash (self.name) - async def handleArchive (self, args, nick, target, message, **kwargs): - """ Handle the archive command """ - - j = Job (args.url, nick) - assert j.id not in self.jobs, 'duplicate job id' - self.jobs[j.id] = j - - cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir, - '--prefix', j.id + '-{host}-{date}-', '--policy', - args.recursive, '--concurrency', str (args.concurrency), - self.destdir] + def __repr__ (self): + return ''.format (self.name, self.modes) - showargs = { - 'recursive': args.recursive, - 'concurrency': args.concurrency, - } - strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ())) - self.send ('PRIVMSG', target=target, message='{}: {} has been queued as {} with {}'.format ( - nick, args.url, j.id, strargs)) + @classmethod + def fromName (cls, name): + """ Get mode and name from NAMES command """ + try: + modes = {NickMode(name[0])} + name = name[1:] + except ValueError: + modes = set () + return cls (name, modes) - async with self.processLimit: - j.process = await asyncio.create_subprocess_exec (*cmdline, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL) - while True: - data = await j.process.stdout.readline () - if not data: - break - - # job is marked running after the first message is received from it - if j.status == Status.pending: - j.status = Status.running - - data = json.loads (data) - msgid = data.get ('uuid') - if msgid == '24d92d16-770e-4088-b769-4020e127a7ff': - j.stats = data - elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c': - j.rstats = data - code = await j.process.wait () +class ReplyContext: + __slots__ = ('client', 'target', 'user') - if j.status == Status.running: - j.status = Status.finished - j.finished = datetime.utcnow () + def __init__ (self, client, target, user): + self.client = client + self.target = target + self.user = user - stats = j.stats - rstats = j.rstats - self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ())) + def __call__ (self, message): + self.client.send ('PRIVMSG', target=self.target, message='{}: {}'.format (self.user.name, message)) - async def handleStatus (self, args, nick, target, message, **kwargs): - """ Handle status command """ +class ArgparseBot (bottom.Client): + """ + Simple IRC bot using argparse + + Tracks user’s modes, reconnects on disconnect + """ - j = self.jobs.get (args.id, None) - if not j: - self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id)) - else: - rstats = j.rstats - self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ())) + __slots__ = ('channels', 'nick', 'parser', 'users') - async def handleAbort (self, args, nick, target, message, **kwargs): - """ Handle abort command """ + def __init__ (self, host, port, ssl, nick, logger, channels=[]): + super().__init__ (host=host, port=port, ssl=ssl) + self.channels = channels + self.nick = nick + # map channel -> nick -> user + self.users = defaultdict (dict) + self.logger = logger + self.parser = self.getParser () - j = self.jobs.get (args.id, None) - if not j: - self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id)) - else: - j.status = Status.aborted - j.process.terminate () + # register bottom event handler + self.on('CLIENT_CONNECT', self.onConnect) + self.on('PING', self.onKeepalive) + self.on('PRIVMSG', self.onMessage) + self.on('CLIENT_DISCONNECT', self.onDisconnect) + self.on('RPL_NAMREPLY', self.onNameReply) + self.on('CHANNELMODE', self.onMode) + self.on('PART', self.onPart) + self.on('JOIN', self.onJoin) + # XXX: we would like to handle KICK, but bottom does not support that at the moment + + @abstractmethod + def getParser (self): + pass async def onConnect (self, **kwargs): + self.logger.info ('connect', nick=self.nick) self.send('NICK', nick=self.nick) self.send('USER', user=self.nick, realname='https://github.com/PromyLOPh/crocoite') @@ -228,7 +210,40 @@ class Bot(bottom.Client): future.cancel() for c in self.channels: - self.send('JOIN', channel=c) + self.logger.info ('join', channel=c) + self.send ('JOIN', channel=c) + # no need for NAMES here, server sends this automatically + + async def onNameReply (self, target, channel_type, channel, users, **kwargs): + self.users[channel] = dict (map (lambda x: (x.name, x), map (User.fromName, users))) + + async def onMode (self, nick, user, host, channel, modes, params, **kwargs): + if channel not in self.channels: + return + + op = modes[0] + for m, nick in zip (map (NickMode.fromMode, modes[1:]), params): + u = self.users[channel].get (nick, User (nick)) + if op == '+': + u.modes.add (m) + elif op == '-': + u.modes.remove (m) + + async def onPart (self, nick, user, host, message, channel, **kwargs): + if channel not in self.channels: + return + + try: + self.users[channel].pop (nick) + except KeyError: + # gone already + pass + + async def onJoin (self, nick, channel, **kwargs): + if channel not in self.channels: + return + + self.users[channel][nick] = User (nick) async def onKeepalive (self, message, **kwargs): """ Ping received """ @@ -237,21 +252,168 @@ class Bot(bottom.Client): async def onMessage (self, nick, target, message, **kwargs): """ Message received """ if target in self.channels and message.startswith (self.nick): + user = self.users[target].get (nick, User (nick)) + reply = ReplyContext (client=self, target=target, user=user) + # channel message that starts with our nick command = message.split (' ')[1:] try: args = self.parser.parse_args (command) except Exception as e: - self.send ('PRIVMSG', target=target, message='{} -- {}'.format (e.args[0], self.parser.format_usage ())) + reply ('{} -- {}'.format (e.args[1], e.args[0].format_usage ())) return if not args: - self.send ('PRIVMSG', target=target, message='Sorry, I don’t understand {}'.format (command)) + reply ('Sorry, I don’t understand {}'.format (command)) return - await args.func (args, nick, target, message, **kwargs) + await args.func (user=user, args=args, reply=reply) async def onDisconnect (**kwargs): """ Auto-reconnect """ + self.logger.info ('disconnect') await asynio.sleep (10, loop=self.loop) + self.logger.info ('reconnect') await self.connect () +def voice (func): + """ Calling user must have voice or ops """ + @wraps (func) + async def inner (self, *args, **kwargs): + user = kwargs.get ('user') + reply = kwargs.get ('reply') + if not user.modes.intersection ({NickMode.operator, NickMode.voice}): + reply ('Sorry, you must have voice to use this command.') + else: + ret = await func (self, *args, **kwargs) + return ret + return inner + +def jobExists (func): + """ Chromebot job exists """ + @wraps (func) + async def inner (self, **kwargs): + # XXX: not sure why it works with **kwargs, but not (user, args, reply) + args = kwargs.get ('args') + reply = kwargs.get ('reply') + j = self.jobs.get (args.id, None) + if not j: + reply ('Job {} is unknown'.format (args.id)) + else: + ret = await func (self, job=j, **kwargs) + return ret + return inner + +class Chromebot (ArgparseBot): + __slots__ = ('jobs', 'tempdir', 'destdir', 'processLimit') + + def __init__ (self, host, port, ssl, nick, logger, channels=[], + tempdir=tempfile.gettempdir(), destdir='.', processLimit=1): + super().__init__ (host=host, port=port, ssl=ssl, nick=nick, + logger=logger, channels=channels) + + self.jobs = {} + self.tempdir = tempdir + self.destdir = destdir + self.processLimit = asyncio.Semaphore (processLimit) + + def getParser (self): + parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False) + subparsers = parser.add_subparsers(help='Sub-commands') + + archiveparser = subparsers.add_parser('a', help='Archive a site', add_help=False) + #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60]) + #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60]) + #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024]) + archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (1, 5)) + archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0') + archiveparser.add_argument('url', help='Website URL', type=isValidUrl, metavar='URL') + archiveparser.set_defaults (func=self.handleArchive) + + statusparser = subparsers.add_parser ('s', help='Get job status', add_help=False) + statusparser.add_argument('id', help='Job id', metavar='UUID') + statusparser.set_defaults (func=self.handleStatus) + + abortparser = subparsers.add_parser ('r', help='Revoke/abort job', add_help=False) + abortparser.add_argument('id', help='Job id', metavar='UUID') + abortparser.set_defaults (func=self.handleAbort) + + return parser + + @voice + async def handleArchive (self, user, args, reply): + """ Handle the archive command """ + + j = Job (args.url, user.name) + assert j.id not in self.jobs, 'duplicate job id' + self.jobs[j.id] = j + + logger = self.logger.bind (id=j.id, user=user.name, url=args.url) + + cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir, + '--prefix', j.id + '-{host}-{date}-', '--policy', + args.recursive, '--concurrency', str (args.concurrency), + self.destdir] + + showargs = { + 'recursive': args.recursive, + 'concurrency': args.concurrency, + } + strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ())) + reply ('{} has been queued as {} with {}'.format (args.url, j.id, strargs)) + logger.info ('queue', cmdline=cmdline) + + async with self.processLimit: + if j.status == Status.pending: + # job was not aborted + j.process = await asyncio.create_subprocess_exec (*cmdline, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + stdin=asyncio.subprocess.DEVNULL) + while True: + data = await j.process.stdout.readline () + if not data: + break + + # job is marked running after the first message is received from it + if j.status == Status.pending: + logger.info ('start') + j.status = Status.running + + data = json.loads (data) + msgid = data.get ('uuid') + if msgid == '24d92d16-770e-4088-b769-4020e127a7ff': + j.stats = data + elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c': + j.rstats = data + code = await j.process.wait () + + if j.status == Status.running: + logger.info ('finish') + j.status = Status.finished + j.finished = datetime.utcnow () + + stats = j.stats + rstats = j.rstats + reply (j.formatStatus ()) + + @jobExists + async def handleStatus (self, user, args, reply, job): + """ Handle status command """ + + rstats = job.rstats + reply (job.formatStatus ()) + + @voice + @jobExists + async def handleAbort (self, user, args, reply, job): + """ Handle abort command """ + + if job.status not in {Status.pending, Status.running}: + reply ('This job is not running.') + return + + job.status = Status.aborted + self.logger.info ('abort', id=job.id, user=user.name) + if job.process and job.process.returncode is None: + job.process.terminate () + -- cgit v1.2.3