diff options
author | Lars-Dominik Braun <lars@6xq.net> | 2018-10-14 12:35:07 +0200 |
---|---|---|
committer | Lars-Dominik Braun <lars@6xq.net> | 2018-10-14 12:35:07 +0200 |
commit | 07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6 (patch) | |
tree | b3c052e605dc1266118a4f660afcfcafad5ebc26 /crocoite | |
parent | 3e69f8b34a48ffa4df4805c53aeaba144d91c6bc (diff) | |
download | crocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.tar.gz crocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.tar.bz2 crocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.zip |
irc: Graceful bot shutdown
Wait for remaining jobs to finish without accepting new ones, but still
allow some interaction with the bot (status/revoke).
Diffstat (limited to 'crocoite')
-rw-r--r-- | crocoite/cli.py | 10 | ||||
-rw-r--r-- | crocoite/irc.py | 75 | ||||
-rw-r--r-- | crocoite/test_irc.py | 41 |
3 files changed, 110 insertions, 16 deletions
diff --git a/crocoite/cli.py b/crocoite/cli.py index 55ff4a1..913db7c 100644 --- a/crocoite/cli.py +++ b/crocoite/cli.py @@ -125,6 +125,7 @@ def irc (): config.read (args.config) s = config['irc'] + loop = asyncio.get_event_loop() bot = Chromebot ( host=s.get ('host'), port=s.getint ('port'), @@ -134,7 +135,10 @@ def irc (): tempdir=s.get ('tempdir'), destdir=s.get ('destdir'), processLimit=s.getint ('process_limit'), - logger=logger) - bot.loop.create_task(bot.connect()) - bot.loop.run_forever() + logger=logger, + loop=loop) + stop = lambda signum: bot.cancel () + loop.add_signal_handler (signal.SIGINT, stop, signal.SIGINT) + loop.add_signal_handler (signal.SIGTERM, stop, signal.SIGTERM) + loop.run_until_complete(bot.run ()) diff --git a/crocoite/irc.py b/crocoite/irc.py index c955337..7d1a96d 100644 --- a/crocoite/irc.py +++ b/crocoite/irc.py @@ -160,6 +160,36 @@ class ReplyContext: def __call__ (self, message): self.client.send ('PRIVMSG', target=self.target, message='{}: {}'.format (self.user.name, message)) +class RefCountEvent: + """ + Ref-counted event that triggers if a) armed and b) refcount drops to zero. + + Must be used as a context manager. + """ + __slots__ = ('count', 'event', 'armed') + + def __init__ (self): + self.armed = False + self.count = 0 + self.event = asyncio.Event () + + def __enter__ (self): + self.count += 1 + self.event.clear () + + def __exit__ (self, exc_type, exc_val, exc_tb): + self.count -= 1 + if self.armed and self.count == 0: + self.event.set () + + async def wait (self): + await self.event.wait () + + def arm (self): + self.armed = True + if self.count == 0: + self.event.set () + class ArgparseBot (bottom.Client): """ Simple IRC bot using argparse @@ -167,10 +197,10 @@ class ArgparseBot (bottom.Client): Tracks user’s modes, reconnects on disconnect """ - __slots__ = ('channels', 'nick', 'parser', 'users') + __slots__ = ('channels', 'nick', 'parser', 'users', '_quit') - def __init__ (self, host, port, ssl, nick, logger, channels=[]): - super().__init__ (host=host, port=port, ssl=ssl) + def __init__ (self, host, port, ssl, nick, logger, channels=[], loop=None): + super().__init__ (host=host, port=port, ssl=ssl, loop=loop) self.channels = channels self.nick = nick # map channel -> nick -> user @@ -178,6 +208,10 @@ class ArgparseBot (bottom.Client): self.logger = logger self.parser = self.getParser () + # bot does not accept new queries in shutdown mode, unless explicitly + # permitted by the parser + self._quit = RefCountEvent () + # register bottom event handler self.on('CLIENT_CONNECT', self.onConnect) self.on('PING', self.onKeepalive) @@ -193,6 +227,16 @@ class ArgparseBot (bottom.Client): def getParser (self): pass + def cancel (self): + self.logger.info ('cancel', uuid='1eb34aea-a854-4fec-90b2-7f8a3812a9cd') + self._quit.arm () + + async def run (self): + await self.connect () + await self._quit.wait () + self.send ('QUIT', message='Bye.') + await self.disconnect () + async def onConnect (self, **kwargs): self.logger.info ('connect', nick=self.nick) @@ -282,14 +326,19 @@ class ArgparseBot (bottom.Client): reply ('Sorry, I don’t understand {}'.format (command)) return - await args.func (user=user, args=args, reply=reply) + if self._quit.armed and not getattr (args, 'allowOnShutdown', False): + reply ('Sorry, I’m shutting down and cannot accept your request right now.') + else: + with self._quit: + await args.func (user=user, args=args, reply=reply) async def onDisconnect (**kwargs): """ Auto-reconnect """ self.logger.info ('disconnect') - await asynio.sleep (10, loop=self.loop) - self.logger.info ('reconnect') - await self.connect () + if not self._quit.armed: + await asynio.sleep (10, loop=self.loop) + self.logger.info ('reconnect') + await self.connect () def voice (func): """ Calling user must have voice or ops """ @@ -323,9 +372,10 @@ class Chromebot (ArgparseBot): __slots__ = ('jobs', 'tempdir', 'destdir', 'processLimit') def __init__ (self, host, port, ssl, nick, logger, channels=[], - tempdir=tempfile.gettempdir(), destdir='.', processLimit=1): + tempdir=tempfile.gettempdir(), destdir='.', processLimit=1, + loop=None): super().__init__ (host=host, port=port, ssl=ssl, nick=nick, - logger=logger, channels=channels) + logger=logger, channels=channels, loop=loop) self.jobs = {} self.tempdir = tempdir @@ -347,11 +397,11 @@ class Chromebot (ArgparseBot): statusparser = subparsers.add_parser ('s', help='Get job status', add_help=False) statusparser.add_argument('id', help='Job id', metavar='UUID') - statusparser.set_defaults (func=self.handleStatus) + statusparser.set_defaults (func=self.handleStatus, allowOnShutdown=True) abortparser = subparsers.add_parser ('r', help='Revoke/abort job', add_help=False) abortparser.add_argument('id', help='Job id', metavar='UUID') - abortparser.set_defaults (func=self.handleAbort) + abortparser.set_defaults (func=self.handleAbort, allowOnShutdown=True) return parser @@ -384,7 +434,8 @@ class Chromebot (ArgparseBot): j.process = await asyncio.create_subprocess_exec (*cmdline, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, - stdin=asyncio.subprocess.DEVNULL) + stdin=asyncio.subprocess.DEVNULL, + start_new_session=True) while True: data = await j.process.stdout.readline () if not data: diff --git a/crocoite/test_irc.py b/crocoite/test_irc.py index 268c604..4d80a6d 100644 --- a/crocoite/test_irc.py +++ b/crocoite/test_irc.py @@ -1,5 +1,25 @@ +# Copyright (c) 2017–2018 crocoite contributors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + import pytest -from .irc import ArgparseBot +from .irc import ArgparseBot, RefCountEvent def test_mode_parse (): assert ArgparseBot.parseMode ('+a') == [('+', 'a')] @@ -12,3 +32,22 @@ def test_mode_parse (): assert ArgparseBot.parseMode ('-a+b') == [('-', 'a'), ('+', 'b')] assert ArgparseBot.parseMode ('-ab+cd') == [('-', 'a'), ('-', 'b'), ('+', 'c'), ('+', 'd')] +@pytest.fixture +def event (): + return RefCountEvent () + +def test_refcountevent_arm (event): + event.arm () + assert event.event.is_set () + +def test_refcountevent_ctxmgr (event): + with event: + assert event.count == 1 + with event: + assert event.count == 2 + +def test_refcountevent_arm_with (event): + with event: + event.arm () + assert not event.event.is_set () + assert event.event.is_set () |