diff options
| -rw-r--r-- | crocoite/cli.py | 9 | ||||
| -rw-r--r-- | crocoite/irc.py | 358 | 
2 files changed, 266 insertions, 101 deletions
| diff --git a/crocoite/cli.py b/crocoite/cli.py index 63199c9..0319dc9 100644 --- a/crocoite/cli.py +++ b/crocoite/cli.py @@ -109,7 +109,9 @@ def recursive ():  def irc ():      from configparser import ConfigParser -    from .irc import Bot +    from .irc import Chromebot + +    logger = Logger (consumer=[DatetimeConsumer (), JsonPrintConsumer ()])      parser = argparse.ArgumentParser(description='IRC bot.')      parser.add_argument('--config', '-c', help='Config file location', metavar='PATH', default='chromebot.ini') @@ -120,7 +122,7 @@ def irc ():      config.read (args.config)      s = config['irc'] -    bot = Bot ( +    bot = Chromebot (              host=s.get ('host'),              port=s.getint ('port'),              ssl=s.getboolean ('ssl'), @@ -128,7 +130,8 @@ def irc ():              channels=[s.get ('channel')],              tempdir=s.get ('tempdir'),              destdir=s.get ('destdir'), -            processLimit=s.getint ('process_limit')) +            processLimit=s.getint ('process_limit'), +            logger=logger)      bot.loop.create_task(bot.connect())      bot.loop.run_forever() diff --git a/crocoite/irc.py b/crocoite/irc.py index d2eda45..878bf5e 100644 --- a/crocoite/irc.py +++ b/crocoite/irc.py @@ -25,7 +25,10 @@ IRC bot “chromebot”  import asyncio, argparse, uuid, json, tempfile  from datetime import datetime  from urllib.parse import urlsplit -from enum import IntEnum +from enum import IntEnum, Enum +from collections import defaultdict +from abc import abstractmethod +from functools import wraps  import bottom  ### helper functions ### @@ -59,17 +62,21 @@ def isValidUrl (s):  class NonExitingArgumentParser (argparse.ArgumentParser):      """ Argument parser that does not call exit(), suitable for interactive use """ +      def exit (self, status=0, message=None):          # should never be called          pass      def error (self, message): -        raise Exception (message) +        # if we use subparsers it’s important to return self, so we can show +        # the correct help +        raise Exception (self, message)      def format_usage (self):          return super().format_usage ().replace ('\n', ' ')  class Status(IntEnum): +    """ Job status """      undefined = 0      pending = 1      running = 2 @@ -77,6 +84,8 @@ class Status(IntEnum):      finished = 4  class Job: +    """ Archival job """ +      __slots__ = ('id', 'stats', 'rstats', 'started', 'finished', 'nick', 'status', 'process', 'url')      def __init__ (self, url, nick): @@ -105,114 +114,87 @@ class Job:                  stats.get ('failed', 0),                  prettyBytes (stats.get ('bytesRcv', 0))) -class Bot(bottom.Client): -    __slots__ = ('jobs', 'channels', 'nick', 'tempdir', 'destdir', 'parser', 'processLimit') +class NickMode(Enum): +    operator = '@' +    voice = '+' -    def __init__ (self, host, port, ssl, nick, channels=[], -            tempdir=tempfile.gettempdir(), destdir='.', processLimit=1): -        super().__init__ (host=host, port=port, ssl=ssl) -        self.jobs = {} -        self.channels = channels -        self.nick = nick -        self.tempdir = tempdir -        self.destdir = destdir -        self.processLimit = asyncio.Semaphore (processLimit) +    @classmethod +    def fromMode (cls, mode): +        return {'v': cls.voice, 'o': cls.operator}[mode] -        self.parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False) -        subparsers = self.parser.add_subparsers(help='Sub-commands') +class User: +    """ IRC user """ +    __slots__ = ('name', 'modes') -        archiveparser = subparsers.add_parser('a', help='Archive a site') -        #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60]) -        #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60]) -        #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024]) -        archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (9)) -        archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0') -        archiveparser.add_argument('url', help='Website URL', type=isValidUrl) -        archiveparser.set_defaults (func=self.handleArchive) +    def __init__ (self, name, modes=set ()): +        self.name = name +        self.modes = modes -        statusparser = subparsers.add_parser ('s', help='Get job status') -        statusparser.add_argument('id', help='Job id', metavar='UUID') -        statusparser.set_defaults (func=self.handleStatus) +    def __eq__ (self, b): +        return self.name == b.name -        abortparser = subparsers.add_parser ('r', help='Revoke/abort job') -        abortparser.add_argument('id', help='Job id', metavar='UUID') -        abortparser.set_defaults (func=self.handleAbort) - -        # register bottom event handler -        self.on('CLIENT_CONNECT', self.onConnect) -        self.on('PING', self.onKeepalive) -        self.on('PRIVMSG', self.onMessage) -        self.on('CLIENT_DISCONNECT', self.onDisconnect) +    def __hash__ (self): +        return hash (self.name) -    async def handleArchive (self, args, nick, target, message, **kwargs): -        """ Handle the archive command """ - -        j = Job (args.url, nick) -        assert j.id not in self.jobs, 'duplicate job id' -        self.jobs[j.id] = j - -        cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir, -                '--prefix', j.id + '-{host}-{date}-', '--policy', -                args.recursive, '--concurrency', str (args.concurrency), -                self.destdir] +    def __repr__ (self): +        return '<User {} {}>'.format (self.name, self.modes) -        showargs = { -                'recursive': args.recursive, -                'concurrency': args.concurrency, -                } -        strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ())) -        self.send ('PRIVMSG', target=target, message='{}: {} has been queued as {} with {}'.format ( -                nick, args.url, j.id, strargs)) +    @classmethod +    def fromName (cls, name): +        """ Get mode and name from NAMES command """ +        try: +            modes = {NickMode(name[0])} +            name = name[1:] +        except ValueError: +            modes = set () +        return cls (name, modes) -        async with self.processLimit: -            j.process = await asyncio.create_subprocess_exec (*cmdline, stdout=asyncio.subprocess.PIPE, -                    stderr=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL) -            while True: -                data = await j.process.stdout.readline () -                if not data: -                    break - -                # job is marked running after the first message is received from it -                if j.status == Status.pending: -                    j.status = Status.running - -                data = json.loads (data) -                msgid = data.get ('uuid') -                if msgid == '24d92d16-770e-4088-b769-4020e127a7ff': -                    j.stats = data -                elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c': -                    j.rstats = data -            code = await j.process.wait () +class ReplyContext: +    __slots__ = ('client', 'target', 'user') -        if j.status == Status.running: -            j.status = Status.finished -        j.finished = datetime.utcnow () +    def __init__ (self, client, target, user): +        self.client = client +        self.target = target +        self.user = user -        stats = j.stats -        rstats = j.rstats -        self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ())) +    def __call__ (self, message): +        self.client.send ('PRIVMSG', target=self.target, message='{}: {}'.format (self.user.name, message)) -    async def handleStatus (self, args, nick, target, message, **kwargs): -        """ Handle status command """ +class ArgparseBot (bottom.Client): +    """ +    Simple IRC bot using argparse +     +    Tracks user’s modes, reconnects on disconnect +    """ -        j = self.jobs.get (args.id, None) -        if not j: -            self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id)) -        else: -            rstats = j.rstats -            self.send ('PRIVMSG', target=target, message='{}: {}'.format (nick, j.formatStatus ())) +    __slots__ = ('channels', 'nick', 'parser', 'users') -    async def handleAbort (self, args, nick, target, message, **kwargs): -        """ Handle abort command """ +    def __init__ (self, host, port, ssl, nick, logger, channels=[]): +        super().__init__ (host=host, port=port, ssl=ssl) +        self.channels = channels +        self.nick = nick +        # map channel -> nick -> user +        self.users = defaultdict (dict) +        self.logger = logger +        self.parser = self.getParser () -        j = self.jobs.get (args.id, None) -        if not j: -            self.send ('PRIVMSG', target=target, message='{}: Job {} is unknown'.format (nick, args.id)) -        else: -            j.status = Status.aborted -            j.process.terminate () +        # register bottom event handler +        self.on('CLIENT_CONNECT', self.onConnect) +        self.on('PING', self.onKeepalive) +        self.on('PRIVMSG', self.onMessage) +        self.on('CLIENT_DISCONNECT', self.onDisconnect) +        self.on('RPL_NAMREPLY', self.onNameReply) +        self.on('CHANNELMODE', self.onMode) +        self.on('PART', self.onPart) +        self.on('JOIN', self.onJoin) +        # XXX: we would like to handle KICK, but bottom does not support that at the moment + +    @abstractmethod +    def getParser (self): +        pass      async def onConnect (self, **kwargs): +        self.logger.info ('connect', nick=self.nick)          self.send('NICK', nick=self.nick)          self.send('USER', user=self.nick, realname='https://github.com/PromyLOPh/crocoite') @@ -228,7 +210,40 @@ class Bot(bottom.Client):              future.cancel()          for c in self.channels: -            self.send('JOIN', channel=c) +            self.logger.info ('join', channel=c) +            self.send ('JOIN', channel=c) +            # no need for NAMES here, server sends this automatically + +    async def onNameReply (self, target, channel_type, channel, users, **kwargs): +        self.users[channel] = dict (map (lambda x: (x.name, x), map (User.fromName, users))) + +    async def onMode (self, nick, user, host, channel, modes, params, **kwargs): +        if channel not in self.channels: +            return + +        op = modes[0] +        for m, nick in zip (map (NickMode.fromMode, modes[1:]), params): +            u = self.users[channel].get (nick, User (nick)) +            if op == '+': +                u.modes.add (m) +            elif op == '-': +                u.modes.remove (m) + +    async def onPart (self, nick, user, host, message, channel, **kwargs): +        if channel not in self.channels: +            return + +        try: +            self.users[channel].pop (nick) +        except KeyError: +            # gone already +            pass + +    async def onJoin (self, nick, channel, **kwargs): +        if channel not in self.channels: +            return + +        self.users[channel][nick] = User (nick)      async def onKeepalive (self, message, **kwargs):          """ Ping received """ @@ -237,21 +252,168 @@ class Bot(bottom.Client):      async def onMessage (self, nick, target, message, **kwargs):          """ Message received """          if target in self.channels and message.startswith (self.nick): +            user = self.users[target].get (nick, User (nick)) +            reply = ReplyContext (client=self, target=target, user=user) +              # channel message that starts with our nick              command = message.split (' ')[1:]              try:                  args = self.parser.parse_args (command)              except Exception as e: -                self.send ('PRIVMSG', target=target, message='{} -- {}'.format (e.args[0], self.parser.format_usage ())) +                reply ('{} -- {}'.format (e.args[1], e.args[0].format_usage ()))                  return              if not args: -                self.send ('PRIVMSG', target=target, message='Sorry, I don’t understand {}'.format (command)) +                reply ('Sorry, I don’t understand {}'.format (command))                  return -            await args.func (args, nick, target, message, **kwargs) +            await args.func (user=user, args=args, reply=reply)      async def onDisconnect (**kwargs):          """ Auto-reconnect """ +        self.logger.info ('disconnect')          await asynio.sleep (10, loop=self.loop) +        self.logger.info ('reconnect')          await self.connect () +def voice (func): +    """ Calling user must have voice or ops """ +    @wraps (func) +    async def inner (self, *args, **kwargs): +        user = kwargs.get ('user') +        reply = kwargs.get ('reply') +        if not user.modes.intersection ({NickMode.operator, NickMode.voice}): +            reply ('Sorry, you must have voice to use this command.') +        else: +            ret = await func (self, *args, **kwargs) +            return ret +    return inner + +def jobExists (func): +    """ Chromebot job exists """ +    @wraps (func) +    async def inner (self, **kwargs): +        # XXX: not sure why it works with **kwargs, but not (user, args, reply) +        args = kwargs.get ('args') +        reply = kwargs.get ('reply') +        j = self.jobs.get (args.id, None) +        if not j: +            reply ('Job {} is unknown'.format (args.id)) +        else: +            ret = await func (self, job=j, **kwargs) +            return ret +    return inner + +class Chromebot (ArgparseBot): +    __slots__ = ('jobs', 'tempdir', 'destdir', 'processLimit') + +    def __init__ (self, host, port, ssl, nick, logger, channels=[], +            tempdir=tempfile.gettempdir(), destdir='.', processLimit=1): +        super().__init__ (host=host, port=port, ssl=ssl, nick=nick, +                logger=logger, channels=channels) + +        self.jobs = {} +        self.tempdir = tempdir +        self.destdir = destdir +        self.processLimit = asyncio.Semaphore (processLimit) + +    def getParser (self): +        parser = NonExitingArgumentParser (prog=self.nick + ': ', add_help=False) +        subparsers = parser.add_subparsers(help='Sub-commands') + +        archiveparser = subparsers.add_parser('a', help='Archive a site', add_help=False) +        #archiveparser.add_argument('--timeout', default=1*60*60, type=int, help='Maximum time for archival', metavar='SEC', choices=[60, 1*60*60, 2*60*60]) +        #archiveparser.add_argument('--idle-timeout', default=10, type=int, help='Maximum idle seconds (i.e. no requests)', dest='idleTimeout', metavar='SEC', choices=[1, 10, 20, 30, 60]) +        #archiveparser.add_argument('--max-body-size', default=None, type=int, dest='maxBodySize', help='Max body size', metavar='BYTES', choices=[1*1024*1024, 10*1024*1024, 100*1024*1024]) +        archiveparser.add_argument('--concurrency', '-j', default=1, type=int, help='Parallel workers for this job', choices=range (1, 5)) +        archiveparser.add_argument('--recursive', '-r', help='Enable recursion', choices=['0', '1', 'prefix'], default='0') +        archiveparser.add_argument('url', help='Website URL', type=isValidUrl, metavar='URL') +        archiveparser.set_defaults (func=self.handleArchive) + +        statusparser = subparsers.add_parser ('s', help='Get job status', add_help=False) +        statusparser.add_argument('id', help='Job id', metavar='UUID') +        statusparser.set_defaults (func=self.handleStatus) + +        abortparser = subparsers.add_parser ('r', help='Revoke/abort job', add_help=False) +        abortparser.add_argument('id', help='Job id', metavar='UUID') +        abortparser.set_defaults (func=self.handleAbort) + +        return parser + +    @voice +    async def handleArchive (self, user, args, reply): +        """ Handle the archive command """ + +        j = Job (args.url, user.name) +        assert j.id not in self.jobs, 'duplicate job id' +        self.jobs[j.id] = j + +        logger = self.logger.bind (id=j.id, user=user.name, url=args.url) + +        cmdline = ['crocoite-recursive', args.url, '--tempdir', self.tempdir, +                '--prefix', j.id + '-{host}-{date}-', '--policy', +                args.recursive, '--concurrency', str (args.concurrency), +                self.destdir] + +        showargs = { +                'recursive': args.recursive, +                'concurrency': args.concurrency, +                } +        strargs = ', '.join (map (lambda x: '{}={}'.format (*x), showargs.items ())) +        reply ('{} has been queued as {} with {}'.format (args.url, j.id, strargs)) +        logger.info ('queue', cmdline=cmdline) + +        async with self.processLimit: +            if j.status == Status.pending: +                # job was not aborted +                j.process = await asyncio.create_subprocess_exec (*cmdline, +                        stdout=asyncio.subprocess.PIPE, +                        stderr=asyncio.subprocess.DEVNULL, +                        stdin=asyncio.subprocess.DEVNULL) +                while True: +                    data = await j.process.stdout.readline () +                    if not data: +                        break + +                    # job is marked running after the first message is received from it +                    if j.status == Status.pending: +                        logger.info ('start') +                        j.status = Status.running + +                    data = json.loads (data) +                    msgid = data.get ('uuid') +                    if msgid == '24d92d16-770e-4088-b769-4020e127a7ff': +                        j.stats = data +                    elif msgid == '5b8498e4-868d-413c-a67e-004516b8452c': +                        j.rstats = data +                code = await j.process.wait () + +        if j.status == Status.running: +            logger.info ('finish') +            j.status = Status.finished +        j.finished = datetime.utcnow () + +        stats = j.stats +        rstats = j.rstats +        reply (j.formatStatus ()) + +    @jobExists +    async def handleStatus (self, user, args, reply, job): +        """ Handle status command """ + +        rstats = job.rstats +        reply (job.formatStatus ()) + +    @voice +    @jobExists +    async def handleAbort (self, user, args, reply, job): +        """ Handle abort command """ + +        if job.status not in {Status.pending, Status.running}: +            reply ('This job is not running.') +            return + +        job.status = Status.aborted +        self.logger.info ('abort', id=job.id, user=user.name) +        if job.process and job.process.returncode is None: +            job.process.terminate () + | 
