From 0b83e60c0b637ee4c9b7f2299dba8742e6b8fc5a Mon Sep 17 00:00:00 2001
From: Lars-Dominik Braun <lars@6xq.net>
Date: Sat, 22 Dec 2018 10:28:11 +0100
Subject: Switch -recursive to asyncio’s .cancel()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

RecursiveController previously used a custom .cancel() method. Instead,
we can simply cancel the task running .run() and handle the resulting
CancelledError inside run() and fetch().
---
 crocoite/cli.py        |   5 ++-
 crocoite/controller.py | 108 +++++++++++++++++++++++++------------------------
 2 files changed, 58 insertions(+), 55 deletions(-)

(limited to 'crocoite')

diff --git a/crocoite/cli.py b/crocoite/cli.py
index b0ad53a..f9ef52c 100644
--- a/crocoite/cli.py
+++ b/crocoite/cli.py
@@ -120,11 +120,12 @@ def recursive ():
             tempdir=args.tempdir, prefix=args.prefix,
             concurrency=args.concurrency)
 
+    run = asyncio.ensure_future (controller.run ())
     loop = asyncio.get_event_loop()
-    stop = lambda signum: controller.cancel ()
+    stop = lambda signum: run.cancel ()
     loop.add_signal_handler (signal.SIGINT, stop, signal.SIGINT)
     loop.add_signal_handler (signal.SIGTERM, stop, signal.SIGTERM)
-    loop.run_until_complete(controller.run ())
+    loop.run_until_complete(run)
     loop.close()
 
     return 0
diff --git a/crocoite/controller.py b/crocoite/controller.py
index bf0d852..435f979 100644
--- a/crocoite/controller.py
+++ b/crocoite/controller.py
@@ -274,7 +274,7 @@ class RecursiveController:
     """
 
     __slots__ = ('url', 'output', 'command', 'logger', 'policy', 'have',
-            'pending', 'stats', 'prefix', 'tempdir', 'running', 'concurrency', '_quit')
+            'pending', 'stats', 'prefix', 'tempdir', 'running', 'concurrency')
 
     SCHEME_WHITELIST = {'http', 'https'}
 
@@ -293,8 +293,6 @@ class RecursiveController:
         self.concurrency = concurrency
         # keep in sync with StatsHandler
         self.stats = {'requests': 0, 'finished': 0, 'failed': 0, 'bytesRcv': 0, 'crashed': 0, 'ignored': 0}
-        # initiate graceful shutdown
-        self._quit = False
 
     async def fetch (self, url):
         """
@@ -327,38 +325,35 @@ class RecursiveController:
         destpath = os.path.join (self.output, os.path.basename (dest.name))
         command = list (map (formatCommand, self.command))
         logger.info ('fetch', uuid='1680f384-744c-4b8a-815b-7346e632e8db', command=command, destfile=destpath)
-        process = await asyncio.create_subprocess_exec (*command, stdout=asyncio.subprocess.PIPE,
-                stderr=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL,
-                start_new_session=True)
-        while True:
-            data = await process.stdout.readline ()
-            if not data:
-                break
-            data = json.loads (data)
-            uuid = data.get ('uuid')
-            if uuid == '8ee5e9c9-1130-4c5c-88ff-718508546e0c':
-                links = set (self.policy (map (lambda x: x.with_fragment(None), data.get ('links', []))))
-                links.difference_update (self.have)
-                self.pending.update (links)
-            elif uuid == '24d92d16-770e-4088-b769-4020e127a7ff':
-                for k in self.stats.keys ():
-                    self.stats[k] += data.get (k, 0)
+        try:
+            process = await asyncio.create_subprocess_exec (*command, stdout=asyncio.subprocess.PIPE,
+                    stderr=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL,
+                    start_new_session=True)
+            while True:
+                data = await process.stdout.readline ()
+                if not data:
+                    break
+                data = json.loads (data)
+                uuid = data.get ('uuid')
+                if uuid == '8ee5e9c9-1130-4c5c-88ff-718508546e0c':
+                    links = set (self.policy (map (lambda x: x.with_fragment(None), data.get ('links', []))))
+                    links.difference_update (self.have)
+                    self.pending.update (links)
+                elif uuid == '24d92d16-770e-4088-b769-4020e127a7ff':
+                    for k in self.stats.keys ():
+                        self.stats[k] += data.get (k, 0)
+                    logStats ()
+        except asyncio.CancelledError:
+            # graceful cancellation
+            process.send_signal (signal.SIGINT)
+        finally:
+            code = await process.wait()
+            if code == 0:
+                # atomically move once finished
+                os.rename (dest.name, destpath)
+            else:
+                self.stats['crashed'] += 1
                 logStats ()
-        code = await process.wait()
-        if code == 0:
-            # atomically move once finished
-            os.rename (dest.name, destpath)
-        else:
-            self.stats['crashed'] += 1
-            logStats ()
-
-    def cancel (self):
-        """ Gracefully cancel this job, waiting for existing workers to shut down """
-        self.logger.info ('cancel',
-                uuid='d58154c8-ec27-40f2-ab9e-e25c1b21cd88',
-                pending=len (self.pending), have=len (self.have),
-                running=len (self.running))
-        self._quit = True
 
     async def run (self):
         def log ():
@@ -367,24 +362,31 @@ class RecursiveController:
                     pending=len (self.pending), have=len (self.have),
                     running=len (self.running))
 
-        self.have = set ()
-        self.pending = set ([self.url])
-
-        while self.pending and not self._quit:
-            # since pending is a set this picks a random item, which is fine
-            u = self.pending.pop ()
-            self.have.add (u)
-            t = asyncio.ensure_future (self.fetch (u))
-            self.running.add (t)
-
+        try:
+            self.have = set ()
+            self.pending = set ([self.url])
+
+            while self.pending:
+                # since pending is a set this picks a random item, which is fine
+                u = self.pending.pop ()
+                self.have.add (u)
+                t = asyncio.ensure_future (self.fetch (u))
+                self.running.add (t)
+
+                log ()
+
+                if len (self.running) >= self.concurrency or not self.pending:
+                    done, pending = await asyncio.wait (self.running,
+                            return_when=asyncio.FIRST_COMPLETED)
+                    self.running.difference_update (done)
+        except asyncio.CancelledError:
+            self.logger.info ('cancel',
+                    uuid='d58154c8-ec27-40f2-ab9e-e25c1b21cd88',
+                    pending=len (self.pending), have=len (self.have),
+                    running=len (self.running))
+        finally:
+            done = await asyncio.gather (*self.running,
+                    return_exceptions=True)
+            self.running = set ()
             log ()
 
-            if len (self.running) >= self.concurrency or not self.pending:
-                done, pending = await asyncio.wait (self.running,
-                        return_when=asyncio.FIRST_COMPLETED)
-                self.running.difference_update (done)
-
-        done = asyncio.gather (*self.running)
-        self.running = set ()
-        log ()
-
-- 
cgit v1.2.3