summaryrefslogtreecommitdiff
path: root/crocoite
diff options
context:
space:
mode:
authorLars-Dominik Braun <lars@6xq.net>2018-10-02 19:23:09 +0200
committerLars-Dominik Braun <lars@6xq.net>2018-10-02 19:24:40 +0200
commit0867960b134680205946bdc05713d07f89f47785 (patch)
tree4cde89e5a4475031b6f46e736281fb911a73d3b2 /crocoite
parent07c34b2d004f16798c17ed479679a511c6bd2f29 (diff)
downloadcrocoite-0867960b134680205946bdc05713d07f89f47785.tar.gz
crocoite-0867960b134680205946bdc05713d07f89f47785.tar.bz2
crocoite-0867960b134680205946bdc05713d07f89f47785.zip
irc: Refactoring/beautification
Add logging, split bot into abstract bot implementation and actual chromebot implementation, move some reusable checks into decorators.
Diffstat (limited to 'crocoite')
-rw-r--r--crocoite/cli.py9
-rw-r--r--crocoite/irc.py358
2 files changed, 266 insertions, 101 deletions
diff --git a/crocoite/cli.py b/crocoite/cli.py
index 63199c9..0319dc9 100644
--- a/crocoite/cli.py
+++ b/crocoite/cli.py
@@ -109,7 +109,9 @@ def recursive ():
def irc ():
from configparser import ConfigParser
- from .irc import Bot
+ from .irc import Chromebot
+
+ logger = Logger (consumer=[DatetimeConsumer (), JsonPrintConsumer ()])
parser = argparse.ArgumentParser(description='IRC bot.')
parser.add_argument('--config', '-c', help='Config file location', metavar='PATH', default='chromebot.ini')
@@ -120,7 +122,7 @@ def irc ():
config.read (args.config)
s = config['irc']
- bot = Bot (
+ bot = Chromebot (
host=s.get ('host'),
port=s.getint ('port'),
ssl=s.getboolean ('ssl'),
@@ -128,7 +130,8 @@ def irc ():
channels=[s.get ('channel')],
tempdir=s.get ('tempdir'),
destdir=s.get ('destdir'),
- processLimit=s.getint ('process_limit'))
+ processLimit=s.getint ('process_limit'),
+ logger=logger)
bot.loop.create_task(bot.connect())
bot.loop.run_forever()
diff --git a/crocoite/irc.py b/crocoite/irc.py
index d2eda45..878bf5e 100644
--- a/crocoite/irc.py
+++ b/crocoite/irc.py
@@ -25,7 +25,10 @@ IRC bot “chromebot”
import asyncio, argparse, uuid, json, tempfile
from datetime import datetime
from urllib.parse import urlsplit
-from enum import IntEnum
+from enum import IntEnum, Enum
+from collections import defaultdict
+from abc import abstractmethod
+from functools import wraps
import bottom
### helper functions ###
@@ -59,17 +62,21 @@ def isValidUrl (s):
class NonExitingArgumentParser (argparse.ArgumentParser):
""" Argument parser that does not call exit(), suitable for interactive use """
+
def exit (self, status=0, message=None):
# should never be called
pass
def error (self, message):
- raise Exception (message)
+ # if we use subparsers it’s important to return self, so we can show
+ # the correct help
+ raise Exception (self, message)
def format_usage (self):
return super().format_usage ().replace ('\n', ' ')
class Status(IntEnum):
+ """ Job status """
undefined = 0
pending = 1
running = 2
@@ -77,6 +84,8 @@ class Status(IntEnum):
finished = 4
class Job:
+ """ Archival job """
+
__slots__ = ('id', 'stats', 'rstats', 'started', 'finished', 'nick', 'status', 'process', 'url')
def __init__ (self, url, nick):
@@ -105,114 +114,87 @@ class Job:
stats.get ('failed', 0),
prettyBytes (stats.get ('bytesRcv', 0)))
-class Bot(bottom.Client):
- __slots__ = ('jobs', 'channels', 'nick', 'tempdir', 'destdir', 'parser', 'processLimit')
+class NickMode(Enum):
+ operator = '@'
+ voice = '+'
- def __init__ (self, host, port, ssl, nick, channels=[],
- tempdir=tempfile.gettempdir(), destdir='.', processLimit=1):
- super().__init__ (host=host, port=port, ssl=ssl)
- self.jobs = {}
- self.channels = channels
- self.nick = nick
- self.tempdir = tempdir
- self.destdir = destdir
- self.processLimit = asyncio.Semaphore (processLimit)
+ @classmethod
+ def fromMode (cls, mode):
+ return {'v': cls.voice, 'o': cls.operator}[mode]
- self.parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False)
- subparsers = self.parser.add_subparsers(help='Sub-commands')
+class User:
+ """ IRC user """
+ __slots__ = ('name', 'modes')
- archiveparser = subparsers.add_parser('a', help='Archive a site')
- #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60])
- #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60])
- #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024])
- archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (9))
- archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0')
- archiveparser.add_argument('url', help='Website URL', type=isValidUrl)
- archiveparser.set_defaults (func=self.handleArchive)
+ def __init__ (self, name, modes=set ()):
+ self.name = name
+ self.modes = modes
- statusparser = subparsers.add_parser ('s', help='Get job status')
- statusparser.add_argument('id', help='Job id', metavar='UUID')
- statusparser.set_defaults (func=self.handleStatus)
+ def __eq__ (self, b):
+ return self.name == b.name
- abortparser = subparsers.add_parser ('r', help='Revoke/abort job')
- abortparser.add_argument('id', help='Job id', metavar='UUID')
- abortparser.set_defaults (func=self.handleAbort)
-
- # register bottom event handler
- self.on('CLIENT_CONNECT', self.onConnect)
- self.on('PING', self.onKeepalive)
- self.on('PRIVMSG', self.onMessage)
- self.on('CLIENT_DISCONNECT', self.onDisconnect)
+ def __hash__ (self):
+ return hash (self.name)
- async def handleArchive (self, args, nick, target, message, **kwargs):
- """ Handle the archive command """
-
- j = Job (args.url, nick)
- assert j.id not in self.jobs, 'duplicate job id'
- self.jobs[j.id] = j
-
- cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir,
- '--prefix', j.id + '-{host}-{date}-', '--policy',
- args.recursive, '--concurrency', str (args.concurrency),
- self.destdir]
+ def __repr__ (self):
+ return '<User {} {}>'.format (self.name, self.modes)
- showargs = {
- 'recursive': args.recursive,
- 'concurrency': args.concurrency,
- }
- strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ()))
- self.send ('PRIVMSG', target=target, message='{}: {} has been queued as {} with {}'.format (
- nick, args.url, j.id, strargs))
+ @classmethod
+ def fromName (cls, name):
+ """ Get mode and name from NAMES command """
+ try:
+ modes = {NickMode(name[0])}
+ name = name[1:]
+ except ValueError:
+ modes = set ()
+ return cls (name, modes)
- async with self.processLimit:
- j.process = await asyncio.create_subprocess_exec (*cmdline, stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL)
- while True:
- data = await j.process.stdout.readline ()
- if not data:
- break
-
- # job is marked running after the first message is received from it
- if j.status == Status.pending:
- j.status = Status.running
-
- data = json.loads (data)
- msgid = data.get ('uuid')
- if msgid == '24d92d16-770e-4088-b769-4020e127a7ff':
- j.stats = data
- elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c':
- j.rstats = data
- code = await j.process.wait ()
+class ReplyContext:
+ __slots__ = ('client', 'target', 'user')
- if j.status == Status.running:
- j.status = Status.finished
- j.finished = datetime.utcnow ()
+ def __init__ (self, client, target, user):
+ self.client = client
+ self.target = target
+ self.user = user
- stats = j.stats
- rstats = j.rstats
- self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ()))
+ def __call__ (self, message):
+ self.client.send ('PRIVMSG', target=self.target, message='{}: {}'.format (self.user.name, message))
- async def handleStatus (self, args, nick, target, message, **kwargs):
- """ Handle status command """
+class ArgparseBot (bottom.Client):
+ """
+ Simple IRC bot using argparse
+
+ Tracks user’s modes, reconnects on disconnect
+ """
- j = self.jobs.get (args.id, None)
- if not j:
- self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id))
- else:
- rstats = j.rstats
- self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ()))
+ __slots__ = ('channels', 'nick', 'parser', 'users')
- async def handleAbort (self, args, nick, target, message, **kwargs):
- """ Handle abort command """
+ def __init__ (self, host, port, ssl, nick, logger, channels=[]):
+ super().__init__ (host=host, port=port, ssl=ssl)
+ self.channels = channels
+ self.nick = nick
+ # map channel -> nick -> user
+ self.users = defaultdict (dict)
+ self.logger = logger
+ self.parser = self.getParser ()
- j = self.jobs.get (args.id, None)
- if not j:
- self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id))
- else:
- j.status = Status.aborted
- j.process.terminate ()
+ # register bottom event handler
+ self.on('CLIENT_CONNECT', self.onConnect)
+ self.on('PING', self.onKeepalive)
+ self.on('PRIVMSG', self.onMessage)
+ self.on('CLIENT_DISCONNECT', self.onDisconnect)
+ self.on('RPL_NAMREPLY', self.onNameReply)
+ self.on('CHANNELMODE', self.onMode)
+ self.on('PART', self.onPart)
+ self.on('JOIN', self.onJoin)
+ # XXX: we would like to handle KICK, but bottom does not support that at the moment
+
+ @abstractmethod
+ def getParser (self):
+ pass
async def onConnect (self, **kwargs):
+ self.logger.info ('connect', nick=self.nick)
self.send('NICK', nick=self.nick)
self.send('USER', user=self.nick, realname='https://github.com/PromyLOPh/crocoite')
@@ -228,7 +210,40 @@ class Bot(bottom.Client):
future.cancel()
for c in self.channels:
- self.send('JOIN', channel=c)
+ self.logger.info ('join', channel=c)
+ self.send ('JOIN', channel=c)
+ # no need for NAMES here, server sends this automatically
+
+ async def onNameReply (self, target, channel_type, channel, users, **kwargs):
+ self.users[channel] = dict (map (lambda x: (x.name, x), map (User.fromName, users)))
+
+ async def onMode (self, nick, user, host, channel, modes, params, **kwargs):
+ if channel not in self.channels:
+ return
+
+ op = modes[0]
+ for m, nick in zip (map (NickMode.fromMode, modes[1:]), params):
+ u = self.users[channel].get (nick, User (nick))
+ if op == '+':
+ u.modes.add (m)
+ elif op == '-':
+ u.modes.remove (m)
+
+ async def onPart (self, nick, user, host, message, channel, **kwargs):
+ if channel not in self.channels:
+ return
+
+ try:
+ self.users[channel].pop (nick)
+ except KeyError:
+ # gone already
+ pass
+
+ async def onJoin (self, nick, channel, **kwargs):
+ if channel not in self.channels:
+ return
+
+ self.users[channel][nick] = User (nick)
async def onKeepalive (self, message, **kwargs):
""" Ping received """
@@ -237,21 +252,168 @@ class Bot(bottom.Client):
async def onMessage (self, nick, target, message, **kwargs):
""" Message received """
if target in self.channels and message.startswith (self.nick):
+ user = self.users[target].get (nick, User (nick))
+ reply = ReplyContext (client=self, target=target, user=user)
+
# channel message that starts with our nick
command = message.split (' ')[1:]
try:
args = self.parser.parse_args (command)
except Exception as e:
- self.send ('PRIVMSG', target=target, message='{} -- {}'.format (e.args[0], self.parser.format_usage ()))
+ reply ('{} -- {}'.format (e.args[1], e.args[0].format_usage ()))
return
if not args:
- self.send ('PRIVMSG', target=target, message='Sorry, I don’t understand {}'.format (command))
+ reply ('Sorry, I don’t understand {}'.format (command))
return
- await args.func (args, nick, target, message, **kwargs)
+ await args.func (user=user, args=args, reply=reply)
async def onDisconnect (**kwargs):
""" Auto-reconnect """
+ self.logger.info ('disconnect')
await asynio.sleep (10, loop=self.loop)
+ self.logger.info ('reconnect')
await self.connect ()
+def voice (func):
+ """ Calling user must have voice or ops """
+ @wraps (func)
+ async def inner (self, *args, **kwargs):
+ user = kwargs.get ('user')
+ reply = kwargs.get ('reply')
+ if not user.modes.intersection ({NickMode.operator, NickMode.voice}):
+ reply ('Sorry, you must have voice to use this command.')
+ else:
+ ret = await func (self, *args, **kwargs)
+ return ret
+ return inner
+
+def jobExists (func):
+ """ Chromebot job exists """
+ @wraps (func)
+ async def inner (self, **kwargs):
+ # XXX: not sure why it works with **kwargs, but not (user, args, reply)
+ args = kwargs.get ('args')
+ reply = kwargs.get ('reply')
+ j = self.jobs.get (args.id, None)
+ if not j:
+ reply ('Job {} is unknown'.format (args.id))
+ else:
+ ret = await func (self, job=j, **kwargs)
+ return ret
+ return inner
+
+class Chromebot (ArgparseBot):
+ __slots__ = ('jobs', 'tempdir', 'destdir', 'processLimit')
+
+ def __init__ (self, host, port, ssl, nick, logger, channels=[],
+ tempdir=tempfile.gettempdir(), destdir='.', processLimit=1):
+ super().__init__ (host=host, port=port, ssl=ssl, nick=nick,
+ logger=logger, channels=channels)
+
+ self.jobs = {}
+ self.tempdir = tempdir
+ self.destdir = destdir
+ self.processLimit = asyncio.Semaphore (processLimit)
+
+ def getParser (self):
+ parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False)
+ subparsers = parser.add_subparsers(help='Sub-commands')
+
+ archiveparser = subparsers.add_parser('a', help='Archive a site', add_help=False)
+ #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60])
+ #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60])
+ #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024])
+ archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (1, 5))
+ archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0')
+ archiveparser.add_argument('url', help='Website URL', type=isValidUrl, metavar='URL')
+ archiveparser.set_defaults (func=self.handleArchive)
+
+ statusparser = subparsers.add_parser ('s', help='Get job status', add_help=False)
+ statusparser.add_argument('id', help='Job id', metavar='UUID')
+ statusparser.set_defaults (func=self.handleStatus)
+
+ abortparser = subparsers.add_parser ('r', help='Revoke/abort job', add_help=False)
+ abortparser.add_argument('id', help='Job id', metavar='UUID')
+ abortparser.set_defaults (func=self.handleAbort)
+
+ return parser
+
+ @voice
+ async def handleArchive (self, user, args, reply):
+ """ Handle the archive command """
+
+ j = Job (args.url, user.name)
+ assert j.id not in self.jobs, 'duplicate job id'
+ self.jobs[j.id] = j
+
+ logger = self.logger.bind (id=j.id, user=user.name, url=args.url)
+
+ cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir,
+ '--prefix', j.id + '-{host}-{date}-', '--policy',
+ args.recursive, '--concurrency', str (args.concurrency),
+ self.destdir]
+
+ showargs = {
+ 'recursive': args.recursive,
+ 'concurrency': args.concurrency,
+ }
+ strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ()))
+ reply ('{} has been queued as {} with {}'.format (args.url, j.id, strargs))
+ logger.info ('queue', cmdline=cmdline)
+
+ async with self.processLimit:
+ if j.status == Status.pending:
+ # job was not aborted
+ j.process = await asyncio.create_subprocess_exec (*cmdline,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.DEVNULL,
+ stdin=asyncio.subprocess.DEVNULL)
+ while True:
+ data = await j.process.stdout.readline ()
+ if not data:
+ break
+
+ # job is marked running after the first message is received from it
+ if j.status == Status.pending:
+ logger.info ('start')
+ j.status = Status.running
+
+ data = json.loads (data)
+ msgid = data.get ('uuid')
+ if msgid == '24d92d16-770e-4088-b769-4020e127a7ff':
+ j.stats = data
+ elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c':
+ j.rstats = data
+ code = await j.process.wait ()
+
+ if j.status == Status.running:
+ logger.info ('finish')
+ j.status = Status.finished
+ j.finished = datetime.utcnow ()
+
+ stats = j.stats
+ rstats = j.rstats
+ reply (j.formatStatus ())
+
+ @jobExists
+ async def handleStatus (self, user, args, reply, job):
+ """ Handle status command """
+
+ rstats = job.rstats
+ reply (job.formatStatus ())
+
+ @voice
+ @jobExists
+ async def handleAbort (self, user, args, reply, job):
+ """ Handle abort command """
+
+ if job.status not in {Status.pending, Status.running}:
+ reply ('This job is not running.')
+ return
+
+ job.status = Status.aborted
+ self.logger.info ('abort', id=job.id, user=user.name)
+ if job.process and job.process.returncode is None:
+ job.process.terminate ()
+