diff options
author | Lars-Dominik Braun <lars@6xq.net> | 2019-11-17 10:09:37 +0100 |
---|---|---|
committer | Lars-Dominik Braun <lars@6xq.net> | 2019-11-17 10:09:37 +0100 |
commit | 1e6ad5d702181bce6aeb3d0704c36f124417227d (patch) | |
tree | 62946bc5090b10d24a4cbbf7fe09dadc3b40c6c5 | |
parent | 41f342e12b975e785de9d755d38eb92cf38f5ec5 (diff) | |
download | lulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.tar.gz lulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.tar.bz2 lulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.zip |
Add more tests
-rw-r--r-- | lulua/stats.py | 46 | ||||
-rw-r--r-- | lulua/test_stats.py | 88 | ||||
-rw-r--r-- | lulua/test_text.py | 48 | ||||
-rw-r--r-- | lulua/text.py | 36 | ||||
-rw-r--r-- | lulua/writer.py | 4 |
5 files changed, 200 insertions, 22 deletions
diff --git a/lulua/stats.py b/lulua/stats.py index a7980d6..4d455c5 100644 --- a/lulua/stats.py +++ b/lulua/stats.py @@ -50,6 +50,12 @@ def updateDictOp (a, b, op): class Stats: name = 'invalid' + def process (self, event): + raise NotImplementedError + + def update (self, other): + raise NotImplementedError + class RunlenStats (Stats): __slots__ = ('lastHand', 'perHandRunlenDist', 'curPerHandRunlen', 'fingerRunlen', 'lastFinger', 'fingerRunlenDist', 'writer') @@ -67,6 +73,11 @@ class RunlenStats (Stats): self.fingerRunlenDist = dict (((x, y), defaultdict (int)) for x, y in product (iter (Direction), iter (FingerType))) self.fingerRunlen = 0 + def __eq__ (self, other): + if not isinstance (other, RunlenStats): + return NotImplemented + return self.perHandRunlenDist == other.perHandRunlenDist + def process (self, event): if isinstance (event, ButtonCombination): assert len (event.buttons) == 1 @@ -90,6 +101,8 @@ class RunlenStats (Stats): self.lastFinger = None self.fingerRunlen = 0 + else: + raise ValueError () def update (self, other): updateDictOp (self.perHandRunlenDist, other.perHandRunlenDist, operator.add) @@ -106,6 +119,13 @@ class SimpleStats (Stats): self.combinations = defaultdict (int) self.unknown = defaultdict (int) + def __eq__ (self, other): + if not isinstance (other, SimpleStats): + return NotImplemented + return self.buttons == other.buttons and \ + self.combinations == other.combinations and \ + self.unknown == other.unknown + def process (self, event): if isinstance (event, SkipEvent): self.unknown[event.char] += 1 @@ -113,6 +133,8 @@ class SimpleStats (Stats): for b in event: self.buttons[b] += 1 self.combinations[event] += 1 + else: + raise ValueError () def update (self, other): updateDictOp (self.buttons, other.buttons, operator.add) @@ -138,6 +160,11 @@ class TriadStats (Stats): keyboard = self._writer.layout.keyboard self._ignored = frozenset (keyboard[x] for x in ('Fl_space', 'Fr_space', 'CD_ret', 'Cl_tab')) + def __eq__ (self, other): + if not isinstance (other, TriadStats): + return NotImplemented + return self.triads == other.triads + def process (self, event): if isinstance (event, SkipEvent): # reset @@ -154,6 +181,8 @@ class TriadStats (Stats): if len (self._triad) == 3: k = tuple (self._triad) self.triads[k] += 1 + else: + raise ValueError () def update (self, other): updateDictOp (self.triads, other.triads, operator.add) @@ -173,6 +202,11 @@ class WordStats (Stats): self._currentWord = [] self.words = defaultdict (int) + def __eq__ (self, other): + if not isinstance (other, WordStats): + return NotImplemented + return self.words == other.words + def process (self, event): if isinstance (event, SkipEvent): # reset @@ -188,6 +222,8 @@ class WordStats (Stats): elif self._currentWord: self.words[''.join (self._currentWord)] += 1 self._currentWord = [] + else: + raise ValueError () def update (self, other): updateDictOp (self.words, other.words, operator.add) @@ -201,11 +237,15 @@ def unpickleAll (fd): except EOFError: break -def combine (args): - keyboard = defaultKeyboards[args.keyboard] +def makeCombined (keyboard): + """ Create a dict which contains initialized stats, ready for combining (not actual writing!) """ layout = defaultLayouts['null'].specialize (keyboard) w = Writer (layout) - combined = dict ((cls.name, cls(w)) for cls in allStats) + return dict ((cls.name, cls(w)) for cls in allStats) + +def combine (args): + keyboard = defaultKeyboards[args.keyboard] + combined = makeCombined (keyboard) for r in unpickleAll (sys.stdin.buffer): for s in allStats: combined[s.name].update (r[s.name]) diff --git a/lulua/test_stats.py b/lulua/test_stats.py index 9e3ed77..3259bcd 100644 --- a/lulua/test_stats.py +++ b/lulua/test_stats.py @@ -21,7 +21,10 @@ import operator import pytest -from .stats import updateDictOp, approx +from .stats import updateDictOp, approx, SimpleStats, TriadStats, allStats +from .keyboard import defaultKeyboards +from .layout import defaultLayouts, ButtonCombination +from .writer import Writer, SkipEvent def test_updateDictOp (): a = {1: 3} @@ -50,3 +53,86 @@ def test_approx (): assert approx (10**9) == (1, 0, 'billion') assert approx (10**12) == (1000, 0, 'billion') +@pytest.fixture +def writer (): + """ Return a default, safe writer with known properties for a fixed layout """ + keyboard = defaultKeyboards['ibmpc105'] + layout = defaultLayouts['ar-lulua'].specialize (keyboard) + return Writer (layout) + +def test_simplestats (writer): + keyboard = writer.layout.keyboard + + s = SimpleStats (writer) + assert not s.unknown + assert not s.combinations + assert not s.buttons + + s.process (SkipEvent ('a')) + assert len (s.unknown) == 1 and s.unknown['a'] == 1 + # no change for those + assert not s.combinations + assert not s.buttons + + dlcaps = keyboard['Dl_caps'] + bl1 = keyboard['Bl1'] + comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1])) + s.process (comb) + assert s.buttons[dlcaps] == 1 + assert s.buttons[bl1] == 1 + assert s.combinations[comb] == 1 + # no change + assert len (s.unknown) == 1 and s.unknown['a'] == 1 + + s2 = SimpleStats (writer) + s2.update (s) + assert s2 == s + +def test_triadstats (writer): + keyboard = writer.layout.keyboard + + s = TriadStats (writer) + assert not s.triads + + s.process (SkipEvent ('a')) + # should not change anything + assert not s.triads + + dlcaps = keyboard['Dl_caps'] + bl1 = keyboard['Bl1'] + comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1])) + for i in range (3): + s.process (comb) + assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 1 + + # sliding window -> increase + s.process (comb) + assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + + # clear sliding window + s.process (SkipEvent ('a')) + assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + + # thus no change here + for i in range (2): + s.process (comb) + assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + + # but here + s.process (comb) + assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 3 + +def test_stats_process_value (writer): + """ Make sure stats classes reject invalid values for .process() """ + + for cls in allStats: + s = cls (writer) + with pytest.raises (ValueError): + s.process (1) + + s.process (SkipEvent ('a')) + s2 = cls (writer) + s2.update (s) + assert s2 == s + assert not s2 == 1 + diff --git a/lulua/test_text.py b/lulua/test_text.py new file mode 100644 index 0000000..65aa3a1 --- /dev/null +++ b/lulua/test_text.py @@ -0,0 +1,48 @@ +# Copyright (c) 2019 lulua contributors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from .text import charMap, mapChars + +def test_map_chars_mapped (): + """ Make sure all chars in the map are mapped correctly """ + + inText = '' + expectText = '' + for k, v in charMap.items (): + inText += k + expectText += v + + outText = mapChars (inText, charMap) + assert outText == expectText + +def test_map_chars_not_mapped (): + """ No chars not in the map should be mapped """ + + inText = '' + expectText = '' + for k, v in charMap.items (): + inText += v + expectText += v + inText += 'a' + expectText += 'a' + + outText = mapChars (inText, charMap) + assert outText == expectText + diff --git a/lulua/text.py b/lulua/text.py index 2d8398d..1d46af8 100644 --- a/lulua/text.py +++ b/lulua/text.py @@ -37,7 +37,7 @@ from html5lib.filters.base import Filter from .keyboard import defaultKeyboards from .layout import defaultLayouts from .writer import Writer -from .stats import allStats +from .stats import allStats, makeCombined def iterchar (fd): batchsize = 1*1024*1024 @@ -244,32 +244,28 @@ charMap = { '\u00a0': ' ', } -def writeWorker (args, inq, outq): +def mapChars (text, m): + """ For all characters in text, replace if found in map m or keep as-is """ + return ''.join (map (lambda x: m.get (x, x), text)) + +def writeWorker (layout, sourceFunc, inq, outq): try: keyboard = defaultKeyboards['ibmpc105'] - layout = defaultLayouts['null'].specialize (keyboard) - w = Writer (layout) - combined = dict ((cls.name, cls(w)) for cls in allStats) + combined = makeCombined (keyboard) itemsProcessed = 0 while True: - keyboard = defaultKeyboards[args.keyboard] - layout = defaultLayouts[args.layout].specialize (keyboard) - w = Writer (layout) - item = inq.get () if item is None: break # extract (can be multiple items per source) - for text in sources[args.source] (item): - text = ''.join (map (lambda x: charMap.get (x, x), text)) - # XXX sanity checks, disable - for c in charMap.keys (): - if c in text: - #print (c, 'is in text', file=sys.stderr) - assert False, c + for text in sourceFunc (item): + # map chars + text = mapChars (text, charMap) + # init a new writer for every item + w = Writer (layout) # stats stats = [cls(w) for cls in allStats] for match, event in w.type (StringIO (text)): @@ -309,6 +305,9 @@ def write (): else: logging.basicConfig (level=logging.INFO) + keyboard = defaultKeyboards[args.keyboard] + layout = defaultLayouts[args.layout].specialize (keyboard) + # limit queue sizes to limit memory usage inq = Queue (args.jobs*2) outq = Queue (args.jobs+1) @@ -316,7 +315,10 @@ def write (): logging.info (f'using {args.jobs} workers') workers = [] for i in range (args.jobs): - p = Process(target=writeWorker, args=(args, inq, outq), daemon=True, name=f'worker-{i}') + p = Process(target=writeWorker, + args=(layout, sources[args.source], inq, outq), + daemon=True, + name=f'worker-{i}') p.start() workers.append (p) diff --git a/lulua/writer.py b/lulua/writer.py index bbe1efb..94ad1b4 100644 --- a/lulua/writer.py +++ b/lulua/writer.py @@ -20,6 +20,7 @@ import json from operator import itemgetter +from typing import Text from .layout import * @@ -109,7 +110,8 @@ defaultFingermap = { class SkipEvent: __slots__ = ('char', ) - def __init__ (self, char): + def __init__ (self, char: Text): + assert len (char) == 1 self.char = char def __eq__ (self, other): |