summaryrefslogtreecommitdiff
path: root/lulua
diff options
context:
space:
mode:
authorLars-Dominik Braun <lars@6xq.net>2019-11-17 10:09:37 +0100
committerLars-Dominik Braun <lars@6xq.net>2019-11-17 10:09:37 +0100
commit1e6ad5d702181bce6aeb3d0704c36f124417227d (patch)
tree62946bc5090b10d24a4cbbf7fe09dadc3b40c6c5 /lulua
parent41f342e12b975e785de9d755d38eb92cf38f5ec5 (diff)
downloadlulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.tar.gz
lulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.tar.bz2
lulua-1e6ad5d702181bce6aeb3d0704c36f124417227d.zip
Add more tests
Diffstat (limited to 'lulua')
-rw-r--r--lulua/stats.py46
-rw-r--r--lulua/test_stats.py88
-rw-r--r--lulua/test_text.py48
-rw-r--r--lulua/text.py36
-rw-r--r--lulua/writer.py4
5 files changed, 200 insertions, 22 deletions
diff --git a/lulua/stats.py b/lulua/stats.py
index a7980d6..4d455c5 100644
--- a/lulua/stats.py
+++ b/lulua/stats.py
@@ -50,6 +50,12 @@ def updateDictOp (a, b, op):
class Stats:
name = 'invalid'
+ def process (self, event):
+ raise NotImplementedError
+
+ def update (self, other):
+ raise NotImplementedError
+
class RunlenStats (Stats):
__slots__ = ('lastHand', 'perHandRunlenDist', 'curPerHandRunlen',
'fingerRunlen', 'lastFinger', 'fingerRunlenDist', 'writer')
@@ -67,6 +73,11 @@ class RunlenStats (Stats):
self.fingerRunlenDist = dict (((x, y), defaultdict (int)) for x, y in product (iter (Direction), iter (FingerType)))
self.fingerRunlen = 0
+ def __eq__ (self, other):
+ if not isinstance (other, RunlenStats):
+ return NotImplemented
+ return self.perHandRunlenDist == other.perHandRunlenDist
+
def process (self, event):
if isinstance (event, ButtonCombination):
assert len (event.buttons) == 1
@@ -90,6 +101,8 @@ class RunlenStats (Stats):
self.lastFinger = None
self.fingerRunlen = 0
+ else:
+ raise ValueError ()
def update (self, other):
updateDictOp (self.perHandRunlenDist, other.perHandRunlenDist, operator.add)
@@ -106,6 +119,13 @@ class SimpleStats (Stats):
self.combinations = defaultdict (int)
self.unknown = defaultdict (int)
+ def __eq__ (self, other):
+ if not isinstance (other, SimpleStats):
+ return NotImplemented
+ return self.buttons == other.buttons and \
+ self.combinations == other.combinations and \
+ self.unknown == other.unknown
+
def process (self, event):
if isinstance (event, SkipEvent):
self.unknown[event.char] += 1
@@ -113,6 +133,8 @@ class SimpleStats (Stats):
for b in event:
self.buttons[b] += 1
self.combinations[event] += 1
+ else:
+ raise ValueError ()
def update (self, other):
updateDictOp (self.buttons, other.buttons, operator.add)
@@ -138,6 +160,11 @@ class TriadStats (Stats):
keyboard = self._writer.layout.keyboard
self._ignored = frozenset (keyboard[x] for x in ('Fl_space', 'Fr_space', 'CD_ret', 'Cl_tab'))
+ def __eq__ (self, other):
+ if not isinstance (other, TriadStats):
+ return NotImplemented
+ return self.triads == other.triads
+
def process (self, event):
if isinstance (event, SkipEvent):
# reset
@@ -154,6 +181,8 @@ class TriadStats (Stats):
if len (self._triad) == 3:
k = tuple (self._triad)
self.triads[k] += 1
+ else:
+ raise ValueError ()
def update (self, other):
updateDictOp (self.triads, other.triads, operator.add)
@@ -173,6 +202,11 @@ class WordStats (Stats):
self._currentWord = []
self.words = defaultdict (int)
+ def __eq__ (self, other):
+ if not isinstance (other, WordStats):
+ return NotImplemented
+ return self.words == other.words
+
def process (self, event):
if isinstance (event, SkipEvent):
# reset
@@ -188,6 +222,8 @@ class WordStats (Stats):
elif self._currentWord:
self.words[''.join (self._currentWord)] += 1
self._currentWord = []
+ else:
+ raise ValueError ()
def update (self, other):
updateDictOp (self.words, other.words, operator.add)
@@ -201,11 +237,15 @@ def unpickleAll (fd):
except EOFError:
break
-def combine (args):
- keyboard = defaultKeyboards[args.keyboard]
+def makeCombined (keyboard):
+ """ Create a dict which contains initialized stats, ready for combining (not actual writing!) """
layout = defaultLayouts['null'].specialize (keyboard)
w = Writer (layout)
- combined = dict ((cls.name, cls(w)) for cls in allStats)
+ return dict ((cls.name, cls(w)) for cls in allStats)
+
+def combine (args):
+ keyboard = defaultKeyboards[args.keyboard]
+ combined = makeCombined (keyboard)
for r in unpickleAll (sys.stdin.buffer):
for s in allStats:
combined[s.name].update (r[s.name])
diff --git a/lulua/test_stats.py b/lulua/test_stats.py
index 9e3ed77..3259bcd 100644
--- a/lulua/test_stats.py
+++ b/lulua/test_stats.py
@@ -21,7 +21,10 @@
import operator
import pytest
-from .stats import updateDictOp, approx
+from .stats import updateDictOp, approx, SimpleStats, TriadStats, allStats
+from .keyboard import defaultKeyboards
+from .layout import defaultLayouts, ButtonCombination
+from .writer import Writer, SkipEvent
def test_updateDictOp ():
a = {1: 3}
@@ -50,3 +53,86 @@ def test_approx ():
assert approx (10**9) == (1, 0, 'billion')
assert approx (10**12) == (1000, 0, 'billion')
+@pytest.fixture
+def writer ():
+ """ Return a default, safe writer with known properties for a fixed layout """
+ keyboard = defaultKeyboards['ibmpc105']
+ layout = defaultLayouts['ar-lulua'].specialize (keyboard)
+ return Writer (layout)
+
+def test_simplestats (writer):
+ keyboard = writer.layout.keyboard
+
+ s = SimpleStats (writer)
+ assert not s.unknown
+ assert not s.combinations
+ assert not s.buttons
+
+ s.process (SkipEvent ('a'))
+ assert len (s.unknown) == 1 and s.unknown['a'] == 1
+ # no change for those
+ assert not s.combinations
+ assert not s.buttons
+
+ dlcaps = keyboard['Dl_caps']
+ bl1 = keyboard['Bl1']
+ comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1]))
+ s.process (comb)
+ assert s.buttons[dlcaps] == 1
+ assert s.buttons[bl1] == 1
+ assert s.combinations[comb] == 1
+ # no change
+ assert len (s.unknown) == 1 and s.unknown['a'] == 1
+
+ s2 = SimpleStats (writer)
+ s2.update (s)
+ assert s2 == s
+
+def test_triadstats (writer):
+ keyboard = writer.layout.keyboard
+
+ s = TriadStats (writer)
+ assert not s.triads
+
+ s.process (SkipEvent ('a'))
+ # should not change anything
+ assert not s.triads
+
+ dlcaps = keyboard['Dl_caps']
+ bl1 = keyboard['Bl1']
+ comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1]))
+ for i in range (3):
+ s.process (comb)
+ assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 1
+
+ # sliding window -> increase
+ s.process (comb)
+ assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2
+
+ # clear sliding window
+ s.process (SkipEvent ('a'))
+ assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2
+
+ # thus no change here
+ for i in range (2):
+ s.process (comb)
+ assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2
+
+ # but here
+ s.process (comb)
+ assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 3
+
+def test_stats_process_value (writer):
+ """ Make sure stats classes reject invalid values for .process() """
+
+ for cls in allStats:
+ s = cls (writer)
+ with pytest.raises (ValueError):
+ s.process (1)
+
+ s.process (SkipEvent ('a'))
+ s2 = cls (writer)
+ s2.update (s)
+ assert s2 == s
+ assert not s2 == 1
+
diff --git a/lulua/test_text.py b/lulua/test_text.py
new file mode 100644
index 0000000..65aa3a1
--- /dev/null
+++ b/lulua/test_text.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 lulua contributors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+from .text import charMap, mapChars
+
+def test_map_chars_mapped ():
+ """ Make sure all chars in the map are mapped correctly """
+
+ inText = ''
+ expectText = ''
+ for k, v in charMap.items ():
+ inText += k
+ expectText += v
+
+ outText = mapChars (inText, charMap)
+ assert outText == expectText
+
+def test_map_chars_not_mapped ():
+ """ No chars not in the map should be mapped """
+
+ inText = ''
+ expectText = ''
+ for k, v in charMap.items ():
+ inText += v
+ expectText += v
+ inText += 'a'
+ expectText += 'a'
+
+ outText = mapChars (inText, charMap)
+ assert outText == expectText
+
diff --git a/lulua/text.py b/lulua/text.py
index 2d8398d..1d46af8 100644
--- a/lulua/text.py
+++ b/lulua/text.py
@@ -37,7 +37,7 @@ from html5lib.filters.base import Filter
from .keyboard import defaultKeyboards
from .layout import defaultLayouts
from .writer import Writer
-from .stats import allStats
+from .stats import allStats, makeCombined
def iterchar (fd):
batchsize = 1*1024*1024
@@ -244,32 +244,28 @@ charMap = {
'\u00a0': ' ',
}
-def writeWorker (args, inq, outq):
+def mapChars (text, m):
+ """ For all characters in text, replace if found in map m or keep as-is """
+ return ''.join (map (lambda x: m.get (x, x), text))
+
+def writeWorker (layout, sourceFunc, inq, outq):
try:
keyboard = defaultKeyboards['ibmpc105']
- layout = defaultLayouts['null'].specialize (keyboard)
- w = Writer (layout)
- combined = dict ((cls.name, cls(w)) for cls in allStats)
+ combined = makeCombined (keyboard)
itemsProcessed = 0
while True:
- keyboard = defaultKeyboards[args.keyboard]
- layout = defaultLayouts[args.layout].specialize (keyboard)
- w = Writer (layout)
-
item = inq.get ()
if item is None:
break
# extract (can be multiple items per source)
- for text in sources[args.source] (item):
- text = ''.join (map (lambda x: charMap.get (x, x), text))
- # XXX sanity checks, disable
- for c in charMap.keys ():
- if c in text:
- #print (c, 'is in text', file=sys.stderr)
- assert False, c
+ for text in sourceFunc (item):
+ # map chars
+ text = mapChars (text, charMap)
+ # init a new writer for every item
+ w = Writer (layout)
# stats
stats = [cls(w) for cls in allStats]
for match, event in w.type (StringIO (text)):
@@ -309,6 +305,9 @@ def write ():
else:
logging.basicConfig (level=logging.INFO)
+ keyboard = defaultKeyboards[args.keyboard]
+ layout = defaultLayouts[args.layout].specialize (keyboard)
+
# limit queue sizes to limit memory usage
inq = Queue (args.jobs*2)
outq = Queue (args.jobs+1)
@@ -316,7 +315,10 @@ def write ():
logging.info (f'using {args.jobs} workers')
workers = []
for i in range (args.jobs):
- p = Process(target=writeWorker, args=(args, inq, outq), daemon=True, name=f'worker-{i}')
+ p = Process(target=writeWorker,
+ args=(layout, sources[args.source], inq, outq),
+ daemon=True,
+ name=f'worker-{i}')
p.start()
workers.append (p)
diff --git a/lulua/writer.py b/lulua/writer.py
index bbe1efb..94ad1b4 100644
--- a/lulua/writer.py
+++ b/lulua/writer.py
@@ -20,6 +20,7 @@
import json
from operator import itemgetter
+from typing import Text
from .layout import *
@@ -109,7 +110,8 @@ defaultFingermap = {
class SkipEvent:
__slots__ = ('char', )
- def __init__ (self, char):
+ def __init__ (self, char: Text):
+ assert len (char) == 1
self.char = char
def __eq__ (self, other):