summaryrefslogtreecommitdiff
path: root/crocoite
diff options
context:
space:
mode:
authorLars-Dominik Braun <lars@6xq.net>2018-10-14 12:35:07 +0200
committerLars-Dominik Braun <lars@6xq.net>2018-10-14 12:35:07 +0200
commit07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6 (patch)
treeb3c052e605dc1266118a4f660afcfcafad5ebc26 /crocoite
parent3e69f8b34a48ffa4df4805c53aeaba144d91c6bc (diff)
downloadcrocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.tar.gz
crocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.tar.bz2
crocoite-07994fb6b72b0c84d2ee2c69e5afdb204d33d5e6.zip
irc: Graceful bot shutdown
Wait for remaining jobs to finish without accepting new ones, but still allow some interaction with the bot (status/revoke).
Diffstat (limited to 'crocoite')
-rw-r--r--crocoite/cli.py10
-rw-r--r--crocoite/irc.py75
-rw-r--r--crocoite/test_irc.py41
3 files changed, 110 insertions, 16 deletions
diff --git a/crocoite/cli.py b/crocoite/cli.py
index 55ff4a1..913db7c 100644
--- a/crocoite/cli.py
+++ b/crocoite/cli.py
@@ -125,6 +125,7 @@ def irc ():
config.read (args.config)
s = config['irc']
+ loop = asyncio.get_event_loop()
bot = Chromebot (
host=s.get ('host'),
port=s.getint ('port'),
@@ -134,7 +135,10 @@ def irc ():
tempdir=s.get ('tempdir'),
destdir=s.get ('destdir'),
processLimit=s.getint ('process_limit'),
- logger=logger)
- bot.loop.create_task(bot.connect())
- bot.loop.run_forever()
+ logger=logger,
+ loop=loop)
+ stop = lambda signum: bot.cancel ()
+ loop.add_signal_handler (signal.SIGINT, stop, signal.SIGINT)
+ loop.add_signal_handler (signal.SIGTERM, stop, signal.SIGTERM)
+ loop.run_until_complete(bot.run ())
diff --git a/crocoite/irc.py b/crocoite/irc.py
index c955337..7d1a96d 100644
--- a/crocoite/irc.py
+++ b/crocoite/irc.py
@@ -160,6 +160,36 @@ class ReplyContext:
def __call__ (self, message):
self.client.send ('PRIVMSG', target=self.target, message='{}: {}'.format (self.user.name, message))
+class RefCountEvent:
+ """
+ Ref-counted event that triggers if a) armed and b) refcount drops to zero.
+
+ Must be used as a context manager.
+ """
+ __slots__ = ('count', 'event', 'armed')
+
+ def __init__ (self):
+ self.armed = False
+ self.count = 0
+ self.event = asyncio.Event ()
+
+ def __enter__ (self):
+ self.count += 1
+ self.event.clear ()
+
+ def __exit__ (self, exc_type, exc_val, exc_tb):
+ self.count -= 1
+ if self.armed and self.count == 0:
+ self.event.set ()
+
+ async def wait (self):
+ await self.event.wait ()
+
+ def arm (self):
+ self.armed = True
+ if self.count == 0:
+ self.event.set ()
+
class ArgparseBot (bottom.Client):
"""
Simple IRC bot using argparse
@@ -167,10 +197,10 @@ class ArgparseBot (bottom.Client):
Tracks user’s modes, reconnects on disconnect
"""
- __slots__ = ('channels', 'nick', 'parser', 'users')
+ __slots__ = ('channels', 'nick', 'parser', 'users', '_quit')
- def __init__ (self, host, port, ssl, nick, logger, channels=[]):
- super().__init__ (host=host, port=port, ssl=ssl)
+ def __init__ (self, host, port, ssl, nick, logger, channels=[], loop=None):
+ super().__init__ (host=host, port=port, ssl=ssl, loop=loop)
self.channels = channels
self.nick = nick
# map channel -> nick -> user
@@ -178,6 +208,10 @@ class ArgparseBot (bottom.Client):
self.logger = logger
self.parser = self.getParser ()
+ # bot does not accept new queries in shutdown mode, unless explicitly
+ # permitted by the parser
+ self._quit = RefCountEvent ()
+
# register bottom event handler
self.on('CLIENT_CONNECT', self.onConnect)
self.on('PING', self.onKeepalive)
@@ -193,6 +227,16 @@ class ArgparseBot (bottom.Client):
def getParser (self):
pass
+ def cancel (self):
+ self.logger.info ('cancel', uuid='1eb34aea-a854-4fec-90b2-7f8a3812a9cd')
+ self._quit.arm ()
+
+ async def run (self):
+ await self.connect ()
+ await self._quit.wait ()
+ self.send ('QUIT', message='Bye.')
+ await self.disconnect ()
+
async def onConnect (self, **kwargs):
self.logger.info ('connect', nick=self.nick)
@@ -282,14 +326,19 @@ class ArgparseBot (bottom.Client):
reply ('Sorry, I don’t understand {}'.format (command))
return
- await args.func (user=user, args=args, reply=reply)
+ if self._quit.armed and not getattr (args, 'allowOnShutdown', False):
+ reply ('Sorry, I’m shutting down and cannot accept your request right now.')
+ else:
+ with self._quit:
+ await args.func (user=user, args=args, reply=reply)
async def onDisconnect (**kwargs):
""" Auto-reconnect """
self.logger.info ('disconnect')
- await asynio.sleep (10, loop=self.loop)
- self.logger.info ('reconnect')
- await self.connect ()
+ if not self._quit.armed:
+ await asynio.sleep (10, loop=self.loop)
+ self.logger.info ('reconnect')
+ await self.connect ()
def voice (func):
""" Calling user must have voice or ops """
@@ -323,9 +372,10 @@ class Chromebot (ArgparseBot):
__slots__ = ('jobs', 'tempdir', 'destdir', 'processLimit')
def __init__ (self, host, port, ssl, nick, logger, channels=[],
- tempdir=tempfile.gettempdir(), destdir='.', processLimit=1):
+ tempdir=tempfile.gettempdir(), destdir='.', processLimit=1,
+ loop=None):
super().__init__ (host=host, port=port, ssl=ssl, nick=nick,
- logger=logger, channels=channels)
+ logger=logger, channels=channels, loop=loop)
self.jobs = {}
self.tempdir = tempdir
@@ -347,11 +397,11 @@ class Chromebot (ArgparseBot):
statusparser = subparsers.add_parser ('s', help='Get job status', add_help=False)
statusparser.add_argument('id', help='Job id', metavar='UUID')
- statusparser.set_defaults (func=self.handleStatus)
+ statusparser.set_defaults (func=self.handleStatus, allowOnShutdown=True)
abortparser = subparsers.add_parser ('r', help='Revoke/abort job', add_help=False)
abortparser.add_argument('id', help='Job id', metavar='UUID')
- abortparser.set_defaults (func=self.handleAbort)
+ abortparser.set_defaults (func=self.handleAbort, allowOnShutdown=True)
return parser
@@ -384,7 +434,8 @@ class Chromebot (ArgparseBot):
j.process = await asyncio.create_subprocess_exec (*cmdline,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.DEVNULL,
- stdin=asyncio.subprocess.DEVNULL)
+ stdin=asyncio.subprocess.DEVNULL,
+ start_new_session=True)
while True:
data = await j.process.stdout.readline ()
if not data:
diff --git a/crocoite/test_irc.py b/crocoite/test_irc.py
index 268c604..4d80a6d 100644
--- a/crocoite/test_irc.py
+++ b/crocoite/test_irc.py
@@ -1,5 +1,25 @@
+# Copyright (c) 2017–2018 crocoite contributors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
import pytest
-from .irc import ArgparseBot
+from .irc import ArgparseBot, RefCountEvent
def test_mode_parse ():
assert ArgparseBot.parseMode ('+a') == [('+', 'a')]
@@ -12,3 +32,22 @@ def test_mode_parse ():
assert ArgparseBot.parseMode ('-a+b') == [('-', 'a'), ('+', 'b')]
assert ArgparseBot.parseMode ('-ab+cd') == [('-', 'a'), ('-', 'b'), ('+', 'c'), ('+', 'd')]
+@pytest.fixture
+def event ():
+ return RefCountEvent ()
+
+def test_refcountevent_arm (event):
+ event.arm ()
+ assert event.event.is_set ()
+
+def test_refcountevent_ctxmgr (event):
+ with event:
+ assert event.count == 1
+ with event:
+ assert event.count == 2
+
+def test_refcountevent_arm_with (event):
+ with event:
+ event.arm ()
+ assert not event.event.is_set ()
+ assert event.event.is_set ()