diff options
| -rw-r--r-- | lulua/stats.py | 46 | ||||
| -rw-r--r-- | lulua/test_stats.py | 88 | ||||
| -rw-r--r-- | lulua/test_text.py | 48 | ||||
| -rw-r--r-- | lulua/text.py | 36 | ||||
| -rw-r--r-- | lulua/writer.py | 4 | 
5 files changed, 200 insertions, 22 deletions
| diff --git a/lulua/stats.py b/lulua/stats.py index a7980d6..4d455c5 100644 --- a/lulua/stats.py +++ b/lulua/stats.py @@ -50,6 +50,12 @@ def updateDictOp (a, b, op):  class Stats:      name = 'invalid' +    def process (self, event): +        raise NotImplementedError + +    def update (self, other): +        raise NotImplementedError +  class RunlenStats (Stats):      __slots__ = ('lastHand', 'perHandRunlenDist', 'curPerHandRunlen',              'fingerRunlen', 'lastFinger', 'fingerRunlenDist', 'writer') @@ -67,6 +73,11 @@ class RunlenStats (Stats):          self.fingerRunlenDist = dict (((x, y), defaultdict (int)) for x, y in product (iter (Direction), iter (FingerType)))          self.fingerRunlen = 0 +    def __eq__ (self, other): +        if not isinstance (other, RunlenStats): +            return NotImplemented +        return self.perHandRunlenDist == other.perHandRunlenDist +              def process (self, event):          if isinstance (event, ButtonCombination):              assert len (event.buttons) == 1 @@ -90,6 +101,8 @@ class RunlenStats (Stats):              self.lastFinger = None              self.fingerRunlen = 0 +        else: +            raise ValueError ()      def update (self, other):          updateDictOp (self.perHandRunlenDist, other.perHandRunlenDist, operator.add) @@ -106,6 +119,13 @@ class SimpleStats (Stats):          self.combinations = defaultdict (int)          self.unknown = defaultdict (int) +    def __eq__ (self, other): +        if not isinstance (other, SimpleStats): +            return NotImplemented +        return self.buttons == other.buttons and \ +                self.combinations == other.combinations and \ +                self.unknown == other.unknown +      def process (self, event):          if isinstance (event, SkipEvent):              self.unknown[event.char] += 1 @@ -113,6 +133,8 @@ class SimpleStats (Stats):              for b in event:                  self.buttons[b] += 1              self.combinations[event] += 1 +        else: +            raise ValueError ()      def update (self, other):          updateDictOp (self.buttons, other.buttons, operator.add) @@ -138,6 +160,11 @@ class TriadStats (Stats):          keyboard = self._writer.layout.keyboard          self._ignored = frozenset (keyboard[x] for x in ('Fl_space', 'Fr_space', 'CD_ret', 'Cl_tab')) +    def __eq__ (self, other): +        if not isinstance (other, TriadStats): +            return NotImplemented +        return self.triads == other.triads +      def process (self, event):          if isinstance (event, SkipEvent):              # reset @@ -154,6 +181,8 @@ class TriadStats (Stats):                  if len (self._triad) == 3:                      k = tuple (self._triad)                      self.triads[k] += 1 +        else: +            raise ValueError ()      def update (self, other):          updateDictOp (self.triads, other.triads, operator.add) @@ -173,6 +202,11 @@ class WordStats (Stats):          self._currentWord = []          self.words = defaultdict (int) +    def __eq__ (self, other): +        if not isinstance (other, WordStats): +            return NotImplemented +        return self.words == other.words +      def process (self, event):          if isinstance (event, SkipEvent):              # reset @@ -188,6 +222,8 @@ class WordStats (Stats):                  elif self._currentWord:                      self.words[''.join (self._currentWord)] += 1                      self._currentWord = [] +        else: +            raise ValueError ()      def update (self, other):          updateDictOp (self.words, other.words, operator.add) @@ -201,11 +237,15 @@ def unpickleAll (fd):          except EOFError:              break -def combine (args): -    keyboard = defaultKeyboards[args.keyboard] +def makeCombined (keyboard): +    """ Create a dict which contains initialized stats, ready for combining (not actual writing!) """      layout = defaultLayouts['null'].specialize (keyboard)      w = Writer (layout) -    combined = dict ((cls.name, cls(w)) for cls in allStats) +    return dict ((cls.name, cls(w)) for cls in allStats) + +def combine (args): +    keyboard = defaultKeyboards[args.keyboard] +    combined = makeCombined (keyboard)      for r in unpickleAll (sys.stdin.buffer):          for s in allStats:              combined[s.name].update (r[s.name]) diff --git a/lulua/test_stats.py b/lulua/test_stats.py index 9e3ed77..3259bcd 100644 --- a/lulua/test_stats.py +++ b/lulua/test_stats.py @@ -21,7 +21,10 @@  import operator  import pytest -from .stats import updateDictOp, approx +from .stats import updateDictOp, approx, SimpleStats, TriadStats, allStats +from .keyboard import defaultKeyboards +from .layout import defaultLayouts, ButtonCombination +from .writer import Writer, SkipEvent  def test_updateDictOp ():      a = {1: 3} @@ -50,3 +53,86 @@ def test_approx ():      assert approx (10**9) == (1, 0, 'billion')      assert approx (10**12) == (1000, 0, 'billion') +@pytest.fixture +def writer (): +    """ Return a default, safe writer with known properties for a fixed layout """ +    keyboard = defaultKeyboards['ibmpc105'] +    layout = defaultLayouts['ar-lulua'].specialize (keyboard) +    return Writer (layout) +     +def test_simplestats (writer): +    keyboard = writer.layout.keyboard + +    s = SimpleStats (writer) +    assert not s.unknown +    assert not s.combinations +    assert not s.buttons + +    s.process (SkipEvent ('a')) +    assert len (s.unknown) == 1 and s.unknown['a'] == 1 +    # no change for those +    assert not s.combinations +    assert not s.buttons + +    dlcaps = keyboard['Dl_caps'] +    bl1 = keyboard['Bl1'] +    comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1])) +    s.process (comb) +    assert s.buttons[dlcaps] == 1 +    assert s.buttons[bl1] == 1 +    assert s.combinations[comb] == 1 +    # no change +    assert len (s.unknown) == 1 and s.unknown['a'] == 1 + +    s2 = SimpleStats (writer) +    s2.update (s) +    assert s2 == s + +def test_triadstats (writer): +    keyboard = writer.layout.keyboard + +    s = TriadStats (writer) +    assert not s.triads + +    s.process (SkipEvent ('a')) +    # should not change anything +    assert not s.triads + +    dlcaps = keyboard['Dl_caps'] +    bl1 = keyboard['Bl1'] +    comb = ButtonCombination (frozenset ([dlcaps]), frozenset ([bl1])) +    for i in range (3): +        s.process (comb) +    assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 1 + +    # sliding window -> increase +    s.process (comb) +    assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + +    # clear sliding window +    s.process (SkipEvent ('a')) +    assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + +    # thus no change here +    for i in range (2): +        s.process (comb) +    assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 2 + +    # but here +    s.process (comb) +    assert len (s.triads) == 1 and s.triads[(comb, comb, comb)] == 3 + +def test_stats_process_value (writer): +    """ Make sure stats classes reject invalid values for .process() """ + +    for cls in allStats: +        s = cls (writer) +        with pytest.raises (ValueError): +            s.process (1) + +        s.process (SkipEvent ('a')) +        s2 = cls (writer) +        s2.update (s) +        assert s2 == s +        assert not s2 == 1 + diff --git a/lulua/test_text.py b/lulua/test_text.py new file mode 100644 index 0000000..65aa3a1 --- /dev/null +++ b/lulua/test_text.py @@ -0,0 +1,48 @@ +# Copyright (c) 2019 lulua contributors +#  +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +#  +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +#  +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from .text import charMap, mapChars + +def test_map_chars_mapped (): +    """ Make sure all chars in the map are mapped correctly """ + +    inText = '' +    expectText = '' +    for k, v in charMap.items (): +        inText += k +        expectText += v + +    outText = mapChars (inText, charMap) +    assert outText == expectText + +def test_map_chars_not_mapped (): +    """ No chars not in the map should be mapped """ + +    inText = '' +    expectText = '' +    for k, v in charMap.items (): +        inText += v +        expectText += v +    inText += 'a' +    expectText += 'a' + +    outText = mapChars (inText, charMap) +    assert outText == expectText + diff --git a/lulua/text.py b/lulua/text.py index 2d8398d..1d46af8 100644 --- a/lulua/text.py +++ b/lulua/text.py @@ -37,7 +37,7 @@ from html5lib.filters.base import Filter  from .keyboard import defaultKeyboards  from .layout import defaultLayouts  from .writer import Writer -from .stats import allStats +from .stats import allStats, makeCombined  def iterchar (fd):      batchsize = 1*1024*1024 @@ -244,32 +244,28 @@ charMap = {      '\u00a0': ' ',      } -def writeWorker (args, inq, outq): +def mapChars (text, m): +    """ For all characters in text, replace if found in map m or keep as-is """ +    return ''.join (map (lambda x: m.get (x, x), text)) + +def writeWorker (layout, sourceFunc, inq, outq):      try:          keyboard = defaultKeyboards['ibmpc105'] -        layout = defaultLayouts['null'].specialize (keyboard) -        w = Writer (layout) -        combined = dict ((cls.name, cls(w)) for cls in allStats) +        combined = makeCombined (keyboard)          itemsProcessed = 0          while True: -            keyboard = defaultKeyboards[args.keyboard] -            layout = defaultLayouts[args.layout].specialize (keyboard) -            w = Writer (layout) -              item = inq.get ()              if item is None:                  break              # extract (can be multiple items per source) -            for text in  sources[args.source] (item): -                text = ''.join (map (lambda x: charMap.get (x, x), text)) -                # XXX sanity checks, disable -                for c in charMap.keys (): -                    if c in text: -                        #print (c, 'is in text', file=sys.stderr) -                        assert False, c +            for text in sourceFunc (item): +                # map chars +                text = mapChars (text, charMap) +                # init a new writer for every item +                w = Writer (layout)                  # stats                  stats = [cls(w) for cls in allStats]                  for match, event in w.type (StringIO (text)): @@ -309,6 +305,9 @@ def write ():      else:          logging.basicConfig (level=logging.INFO) +    keyboard = defaultKeyboards[args.keyboard] +    layout = defaultLayouts[args.layout].specialize (keyboard) +      # limit queue sizes to limit memory usage      inq = Queue (args.jobs*2)      outq = Queue (args.jobs+1) @@ -316,7 +315,10 @@ def write ():      logging.info (f'using {args.jobs} workers')      workers = []      for i in range (args.jobs): -        p = Process(target=writeWorker, args=(args, inq, outq), daemon=True, name=f'worker-{i}') +        p = Process(target=writeWorker, +                args=(layout, sources[args.source], inq, outq), +                daemon=True, +                name=f'worker-{i}')          p.start()          workers.append (p) diff --git a/lulua/writer.py b/lulua/writer.py index bbe1efb..94ad1b4 100644 --- a/lulua/writer.py +++ b/lulua/writer.py @@ -20,6 +20,7 @@  import json  from operator import itemgetter +from typing import Text  from .layout import * @@ -109,7 +110,8 @@ defaultFingermap = {  class SkipEvent:      __slots__ = ('char', ) -    def __init__ (self, char): +    def __init__ (self, char: Text): +        assert len (char) == 1          self.char = char      def __eq__ (self, other): | 
