소스 검색

Use ExitStack and decorators for read transactions; consolidate database and index lists;
a few initial tests.

Stefano Cossu 7 년 전
부모
커밋
0a2f70a6c0
3개의 변경된 파일, 287개의 추가 및 172개의 삭제
  1. 174 143
      lakesuperior/store_layouts/ldp_rs/lmdb_store.py
  2. 0 29
      tests/store/lmdb_store.py
  3. 113 0
      tests/store/test_lmdb_store.py

+ 174 - 143
lakesuperior/store_layouts/ldp_rs/lmdb_store.py

@@ -1,15 +1,15 @@
 import hashlib
 import logging
 
+from contextlib import ExitStack
 from os import makedirs
 from os.path import exists, abspath
+from urllib.request import pathname2url
 
 import lmdb
 
 from rdflib.store import Store, VALID_STORE, NO_STORE
 from rdflib.term import URIRef
-from six import b
-from six.moves.urllib.request import pathname2url
 
 
 logger = logging.getLogger(__name__)
@@ -30,6 +30,32 @@ class NoTxnError(Exception):
         return 'No transaction active in the store.'
 
 
+def read_tx(dbs=(), buffers=True):
+    '''
+    Decorator to wrap a method into a read transaction.
+
+    This method creates the necessary cursors indicated in the `db` parameter.
+
+    @param dbs (tuple|list:string) Database label(s) to open cursors. No
+    cursors are automatically opened by default.
+    '''
+    def read_tx_deco(fn):
+        def wrapper(self, *args, **kwargs):
+            with ExitStack() as stack:
+                self.rtxn = stack.enter_context(
+                        self.db_env.begin(buffers=buffers))
+                self.rcurs = {}
+                for db_label in dbs:
+                    self.rcurs[db_label] = stack.enter_context(
+                            self.rtxn.cursor(self.dbs[db_label]))
+                stack.pop_all()
+                ret = fn(self, *args, **kwargs)
+
+                return ret
+        return wrapper
+    return read_tx_deco
+
+
 class LmdbStore(Store):
     '''
     LMDB-backed store.
@@ -94,7 +120,7 @@ class LmdbStore(Store):
 
     db_env = None
     db = None
-    idx_db = {}
+    dbs = {}
     txn = None
 
 
@@ -121,15 +147,18 @@ class LmdbStore(Store):
                 max_dbs=12, readahead=False)
 
         # Open and optionally create main databases.
-        self.trp_db = self.db_env.open_db(b'tk:t', create=create)
-        self.ctx_db = self.db_env.open_db(b'tk:c', create=create, dupsort=True)
-        self.pfx_db = self.db_env.open_db(b'pfx:ns', create=create)
-        # Index databases.
+        self.dbs = {
+            # Main databases.
+            'tk:t': self.db_env.open_db(b'tk:t', create=create),
+            'tk:c': self.db_env.open_db(b'tk:c', create=create, dupsort=True),
+            'pfx:ns': self.db_env.open_db(b'pfx:ns', create=create),
+            # Index.
+            'ns:pfx': self.db_env.open_db(b'ns:pfx', create=create),
+        }
+        # Other index databases.
         for db_key in self.idx_keys:
-            self.idx_db[db_key] = self.db_env.open_db(s2b(db_key),
+            self.dbs[db_key] = self.db_env.open_db(s2b(db_key),
                     dupsort=True, dupfixed=True, create=create)
-        self.idx_db['ns:pfx'] = self.db_env.open_db(
-                b'ns:pfx', create=create)
 
 
     @property
@@ -165,10 +194,10 @@ class LmdbStore(Store):
         '''
         if not self.is_open:
             raise RuntimeError('Store must be opened first.')
-        self.txn = self.db_env.begin(write=True, buffers=True)
+        self.wtxn = self.db_env.begin(write=True, buffers=True)
         # Cursors.
-        self.data_cur = self.get_data_cursors(self.txn)
-        self.idx_cur = self.get_idx_cursors(self.txn)
+        self.wcurs = self.get_data_cursors(self.wtxn)
+        self.wcurs.update(self.get_idx_cursors(self.wtxn))
 
 
     def get_data_cursors(self, txn):
@@ -181,9 +210,9 @@ class LmdbStore(Store):
         index cursors.
         '''
         return {
-            'tk:t': txn.cursor(self.trp_db),
-            'tk:c': txn.cursor(self.ctx_db),
-            'pfx:ns': txn.cursor(self.pfx_db),
+            'tk:t': txn.cursor(self.dbs['tk:t']),
+            'tk:c': txn.cursor(self.dbs['tk:c']),
+            'pfx:ns': txn.cursor(self.dbs['ns:pfx']),
         }
 
 
@@ -198,7 +227,7 @@ class LmdbStore(Store):
         '''
         cur = {}
         for key in self.idx_keys:
-            cur[key] = self.txn.cursor(self.idx_db[key])
+            cur[key] = self.wtxn.cursor(self.dbs[key])
 
         return cur
 
@@ -209,7 +238,7 @@ class LmdbStore(Store):
         Whether the main write transaction is open.
         '''
         try:
-            self.txn.id()
+            self.wtxn.id()
         except lmdb.Error:
             return False
         else:
@@ -249,10 +278,10 @@ class LmdbStore(Store):
         pk_trp = self._pickle(triple)
         trp_key = hashlib.new(self.KEY_HASH_ALGO, pk_trp).digest()
         # If it returns False, the triple had already been added.
-        trp_added = self.data_cur['tk:t'].put(trp_key, pk_trp, overwrite=False)
+        trp_added = self.wcurs['tk:t'].put(trp_key, pk_trp, overwrite=False)
 
         pk_ctx = self._pickle(context)
-        ctx_added = self.data_cur['tk:c'].put(trp_key, pk_ctx, overwrite=False)
+        ctx_added = self.wcurs['tk:c'].put(trp_key, pk_ctx, overwrite=False)
 
         if ctx_added or trp_added:
             # @TODO make await
@@ -270,14 +299,14 @@ class LmdbStore(Store):
             need_indexing = False
 
             # Delete context association.
-            if self.data_cur['tk:c'].set_key_dup(pk_ctx, trp_key):
-                self.data_cur['tk:c'].delete()
+            if self.wcurs['tk:c'].set_key_dup(trp_key, pk_ctx):
+                self.wcurs['tk:c'].delete()
                 need_indexing = True
 
                 # If no other contexts are associated w/ the triple, delete it.
-                if not self.data_cur['tk:c'].set_key(trp_key) and (
-                        self.data_cur['tk:t'].set_key(trp_key)):
-                    self.data_cur['tk:t'].delete()
+                if not self.wcurs['tk:c'].set_key(trp_key) and (
+                        self.wcurs['tk:t'].set_key(trp_key)):
+                    self.wcurs['tk:t'].delete()
 
                 # @TODO make await
                 self._do_index(trp, trp_key, pk_ctx)
@@ -304,89 +333,77 @@ class LmdbStore(Store):
             'pok:tk': self._to_key((p, o)),
         }
 
-        if self.data_cur['tk:t'].get(trp_key):
+        if self.wcurs['tk:t'].get(trp_key):
             # Add to index.
             for ikey in term_keys:
-                self.idx_cur[ikey].put(term_keys[ikey], trp_key)
+                self.wcurs[ikey].put(term_keys[ikey], trp_key)
         else:
             # Delete from index if a match is found.
-            for ikey in self.term_keys:
-                if self.idx_cur[ikey].set_key_dup(term_keys[ikey], trp_key):
-                    self.idx_cur[ikey].delete()
+            for ikey in term_keys:
+                if self.wcurs[ikey].set_key_dup(term_keys[ikey], trp_key):
+                    self.wcurs[ikey].delete()
 
         # Add or remove context association index.
-        if self.data_cur['tk:c'].get(trp_key, pk_ctx):
-            self.idx_cur['c:tk'].put(pk_ctx, trp_key)
+        if self.wcurs['tk:c'].get(trp_key, pk_ctx):
+            self.wcurs['c:tk'].put(pk_ctx, trp_key)
         else:
-            if self.idx_cur['c:tk'].set_key_dup(pk_ctx, trp_key):
-                self.idx_cur['c:tk'].delete()
+            if self.wcurs['c:tk'].set_key_dup(pk_ctx, trp_key):
+                self.wcurs['c:tk'].delete()
 
 
+    @read_tx((
+        'sk:tk', 'pk:tk', 'ok:tk', 'spk:tk', 'sok:tk', 'pok:tk',
+        'c:tk', 'tk:c', 'tk:t'))
     def triples(self, triple_pattern, context=None):
         '''
         Generator over matching triples.
         '''
-        assert self.__open, "The Store must be open."
         if context == self:
             context = None
+        context = context or self.DEFAULT_GRAPH_URI
 
         tkey = self._to_key(triple_pattern)
 
-        with self.db_env.begin(buffers=True) as txn:
-            if context is not None:
-                pk_ctx = self._pickle(context)
-                if not self.idx_cur['c:tk'].set_key(pk_ctx):
-                    # Context not found.
-                    return iter(())
-                # If all triple elements are bound
-                if all(triple_pattern):
-                    with txn.cursor(self.ctx_db) as cur:
-                        if cur.set_key_dup(tkey, pk_ctx):
-                            yield self._key_to_triple(tkey)
-                        else:
-                            # Triple not found.
-                            return iter(())
-                # If some are unbound
+        # Any pattern with unbound context
+        if context == self.DEFAULT_GRAPH_URI:
+            for tk in self._lookup(triple_pattern, tkey):
+                yield self._key_to_triple(tk)
+
+        # Shortcuts
+        else:
+            pk_ctx = self._pickle(context)
+            if not self.rcurs['c:tk'].set_key(pk_ctx):
+                # Context not found.
+                return iter(())
+
+            # s p o c
+            if all(triple_pattern):
+                if self.rcurs['tk:c'].set_key_dup(tkey, pk_ctx):
+                    yield self._key_to_triple(tkey)
+                    return
                 else:
-                    # If some are bound
-                    if any(triple_pattern):
-                        # Find the lookup index
-                        with txn.cursor(self.idx_db['c:tk']) as cur:
-                            for tk in self._lookup(triple_pattern):
-                                if cur.set_key_dup(pk_ctx, tk):
-                                    yield self._key_to_triple(tk)
-                    # If all are unbound
-                    else:
-                        # Get all triples from the context
-                        with txn.cursor(self.idx_db['c:tk']) as cur:
-                            for tk in cur.iternext_dup():
-                                yield self._key_to_triple(tk)
-            # If context is unbound
+                    # Triple not found.
+                    return iter(())
+
+            # ? ? ? c
+            elif not any(triple_pattern):
+                # Get all triples from the context
+                for tk in self.rcurs['c:tk'].iternext_dup():
+                    yield self._key_to_triple(tk)
+
             else:
-                # If all triples are bound
-                if all(triple_pattern):
-                    with txn.cursor(self.trp_db) as cur:
-                        match = cur.set_key(tkey)
-                        if match:
-                            yield self._key_to_triple(match)
-                        else:
-                            return iter(())
-                # If some are unbound
-                else:
-                    # If some are bound
-                    if any(triple_pattern):
-                        return self._lookup(triple_pattern)
-                    # If all are unbound
-                    else:
-                        # Get all triples in the database
-                        with txn.cursor(self.trp_db) as cur:
-                            pk_triples = cur.iternext(keys=False)
-                            for pk_trp in pk_triples:
-                                yield self._unpickle(pk_trp)
+                # Regular lookup.
+                for tk in self._lookup(triple_pattern, tkey):
+                    if self.rcurs['c:tk'].set_key_dup(pk_ctx, tk):
+                        yield self._key_to_triple(tk)
 
 
+
+    @read_tx()
     def __len__(self, context=None):
-        assert self.__open, "The Store must be open."
+        '''
+        Return length of the dataset.
+        '''
         if context == self:
             context = None
 
@@ -394,8 +411,7 @@ class LmdbStore(Store):
             dataset = self.triples((None, None, None), context)
             return len(set(dataset))
         else:
-            with self.environment.begin() as txn:
-                return txn.stat(self.trp_db)['entries']
+            return self.rtxn.stat(self.dbs['tk:t'])['entries']
 
 
     def bind(self, prefix, namespace):
@@ -404,21 +420,23 @@ class LmdbStore(Store):
         '''
         prefix = s2b(prefix)
         namespace = s2b(namespace)
-        with self.txn.cursor(self.idx_db(b'ns:pfx')) as cur:
+        with self.wtxn.cursor(self.dbs(b'ns:pfx')) as cur:
             cur.put(namespace, prefix)
-        with self.txn.cursor(self.idx_db(b'pfx:ns')) as cur:
+        with self.wtxn.cursor(self.dbs(b'pfx:ns')) as cur:
             cur.put(prefix, namespace)
 
 
+    @read_tx(('pfx:ns',))
     def namespace(self, prefix):
         '''
         Get the namespace for a prefix.
         '''
-        ns = self.idx_cur['pfx:ns'].get(s2b(prefix))
+        ns = self.rcurs['pfx:ns'].get(s2b(prefix))
 
         return URIRef(b2s(ns)) if ns is not None else None
 
 
+    @read_tx(('ns:pfx',))
     def prefix(self, namespace):
         '''
         Get the prefix associated with a namespace.
@@ -426,32 +444,31 @@ class LmdbStore(Store):
         @NOTE A namespace can be only bound to one prefix in this
         implementation.
         '''
-        prefix = self.data_cur['ns:pfx'].get(s2b(namespace))
+        prefix = self.rcurs['pfx:ns'].get(s2b(namespace))
 
         return b2s(prefix) if prefix is not None else None
 
 
+    @read_tx(('ns:pfx',))
     def namespaces(self):
         '''
         Get a dict of all prefix: namespace bindings.
         '''
-        with self.tx.cursor(self.pfx_db) as cur:
-            bindings = iter(cur)
+        bindings = iter(self.rcurs['ns:pfx'])
 
         return ((b2s(pfx), b2s(ns)) for pfx, ns in bindings)
 
 
+    @read_tx(('tk:c','c:tk'))
     def contexts(self, triple=None):
         '''
         Get a list of all contexts.
         '''
         if triple:
-            with self.tx.cursor(self.ctx_db) as cur:
-                cur.set_key(self._to_key(triple))
-                contexts = cur.iternext_dup()
+            self.rcurs['tk:c'].set_key(self._to_key(triple))
+            contexts = self.rcurs['tk:c'].iternext_dup()
         else:
-            with self.tx.cursor(self.idx_db[b'c:tk']) as cur:
-                contexts = cur.iternext_nodup()
+            contexts = self.rcurs['c:tk'].iternext_nodup()
 
         return (b2s(ctx) for ctx in contexts)
 
@@ -460,8 +477,8 @@ class LmdbStore(Store):
         '''
         Add a graph to the database.
         '''
-        self.data_cur['tk:c'].put(self._pickle(None), self._pickle(graph))
-        self.idx_cur['c:tk'].put(self._pickle(graph), self._pickle(None))
+        self.wcurs['tk:c'].put(self._pickle(None), self._pickle(graph))
+        self.wcurs['c:tk'].put(self._pickle(graph), self._pickle(None))
 
 
     def remove_graph(self, graph):
@@ -470,27 +487,27 @@ class LmdbStore(Store):
         '''
         self.remove((None, None, None), graph)
 
-        if self.data_cur['tk:c'].set_key_dup(
+        if self.wcurs['tk:c'].set_key_dup(
                 self._pickle(None), self._pickle(graph)):
-            self.data_cur['tk:c'].delete()
+            self.wcurs['tk:c'].delete()
 
-        if self.idx_cur['c:tk'].set_key_dup(
+        if self.wcurs['c:tk'].set_key_dup(
                 self._pickle(graph), self._pickle(None)):
-            self.data_cur['tk:c'].delete()
+            self.wcurs['tk:c'].delete()
 
 
     def commit(self):
         '''
-        Commit main transaction.
+        Commit main write transaction.
         '''
-        self.txn.commit()
+        self.wtxn.commit()
 
 
     def rollback(self):
         '''
-        Roll back main transaction.
+        Roll back main write transaction.
         '''
-        self.txn.abort()
+        self.wtxn.abort()
 
 
     #def _next_lex_key(self, db=None):
@@ -520,6 +537,8 @@ class LmdbStore(Store):
     #    return next
 
 
+    ## PRIVATE METHODS ##
+
     def _to_key(self, obj):
         '''
         Convert a triple, quad or term into a key.
@@ -544,55 +563,67 @@ class LmdbStore(Store):
 
         @return Tuple with triple elements or None if key is not found.
         '''
-        pk_trp = self.data_cur['tk:t'].get(key)
+        pk_trp = self.rcurs['tk:t'].get(key)
 
         return self._unpickle(pk_trp) if pk_trp else None
 
 
-    def _lookup(self, triple_pattern):
+    def _lookup(self, triple_pattern, tkey=None):
         '''
         Look up triples based on a triple pattern.
 
-        This is only used if one or two terms are nubound. If all terms are
-        either bound  or unbound, other methods should be used.
-
         @return iterator of matching triple keys.
         '''
-        if not any(triple_pattern) or all(triple_pattern):
-            raise ValueError(
-                    'This method is not usable with a triple with only '
-                    'unbound or only bound terms.')
-
         s, p, o = triple_pattern
 
-        with self.env.begin(buffers=True) as txn:
-            if s is None:
-                if p is None:
-                    cursor = self.txn.cursor(self.idx_db['o:tk'])
-                    term = self._pickle(o)
-                else:
-                    cursor = self.txn.cursor(self.idx_db['po:tk'])
-                    term = self._pickle((p, o))
-            if p is None:
-                if o is None:
-                    cursor = self.txn.cursor(self.idx_db['s:tk'])
-                    term = self._pickle(s)
+        if s is not None:
+            if p is not None:
+                # s p o
+                if o is not None:
+                    if self.rcurs['tk:t'].set_key(tkey):
+                        yield tkey
+                        return
+                    else:
+                        return iter(())
+                # s p ?
                 else:
-                    cursor = self.txn.cursor(self.idx_db['so:tk'])
+                    cur = self.rcurs['spk:tk']
+                    term = self._pickle((s, p))
+            else:
+                # s ? o
+                if o is not None:
+                    cur = self.rcurs['sok:tk']
                     term = self._pickle((s, o))
-            if o is None:
-                if s is None:
-                    cursor = self.txn.cursor(self.idx_db['p:tk'])
+                # s ? ?
+                else:
+                    cur = self.rcurs['sk:tk']
                     term = self._pickle(s)
+        else:
+            if p is not None:
+                # ? p o
+                if o is not None:
+                    cur = self.rcurs['pok:tk']
+                    term = self._pickle((p, o))
+                # ? p ?
                 else:
-                    cursor = self.txn.cursor(self.idx_db['sp:tk'])
-                    term = self._pickle((s, p))
-
-            key = hashlib.new(self.KEY_HASH_ALGO, term).digest()
-            with cursor as cur:
-                if cur.set_key(key):
-                    for match in cur.iternext_dup():
-                        yield match
+                    cur = self.rcurs['pk:tk']
+                    term = self._pickle(p)
+            else:
+                # ? ? o
+                if o is not None:
+                    cur = self.rcurs['ok:tk']
+                    term = self._pickle(o)
+                # ? ? ?
                 else:
-                    return iter(())
+                    # Get all triples in the database
+                    for c in self.rcurs['tk:t'].iternext(values=False):
+                        yield c
+                    return
+
+        key = hashlib.new(self.KEY_HASH_ALGO, term).digest()
+        if cur.set_key(key):
+            for match in cur.iternext_dup():
+                yield match
+        else:
+            return iter(())
 

+ 0 - 29
tests/store/lmdb_store.py

@@ -1,29 +0,0 @@
-import pytest
-
-from rdflib import URIRef
-
-from lakesuperior.store_layouts.ldp_rs.lmdb_store import LmdbStore
-
-@pytest.fixture(scope='module')
-def store():
-    return LmdbStore('/tmp/lmdbstore')
-
-
-@pytest.mark.usefixtures('store')
-class TestLmdbStore:
-    '''
-    Unit tests for LMDB store.
-    '''
-    def test_create_triple(self, store):
-        '''
-        Test creation of a single triple.
-        '''
-        store.begin()
-        store.add((
-            URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o')))
-        store.commit()
-
-        res = set(store.triples((None, None, None)))
-        assert len(res) == 1
-        assert (URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o')) \
-                in res

+ 113 - 0
tests/store/test_lmdb_store.py

@@ -0,0 +1,113 @@
+import pytest
+
+from shutil import rmtree
+
+from rdflib import URIRef
+
+from lakesuperior.store_layouts.ldp_rs.lmdb_store import LmdbStore
+
+@pytest.fixture(scope='class')
+def store():
+    store = LmdbStore('/tmp/test_lmdbstore')
+    yield store
+    store.close()
+    rmtree('/tmp/test_lmdbstore')
+
+
+@pytest.mark.usefixtures('store')
+class TestLmdbStore:
+    '''
+    Unit tests for LMDB store.
+    '''
+    def test_create_triple(self, store):
+        '''
+        Test creation of a single triple.
+        '''
+        store.begin()
+        store.add((
+            URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o')))
+        store.commit()
+
+        res1 = set(store.triples((None, None, None)))
+        res2 = set(store.triples((
+            URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o'))))
+        assert len(res1) == 1
+        assert len(res2) == 1
+        assert (
+            URIRef('urn:test:s'), URIRef('urn:test:p'),
+            URIRef('urn:test:o')) in res1 & res2
+
+
+    def test_triple_match_1bound(self, store):
+        '''
+        Test triple patterns matching one bound term.
+        '''
+        res1 = set(store.triples((URIRef('urn:test:s'), None, None)))
+        res2 = set(store.triples((None, URIRef('urn:test:p'), None)))
+        res3 = set(store.triples((None, None, URIRef('urn:test:o'))))
+        assert res1 == {(
+            URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o'))}
+        assert res2 == res1
+        assert res3 == res2
+
+
+    def test_triple_match_2bound(self, store):
+        '''
+        Test triple patterns matching two bound terms.
+        '''
+        res1 = set(store.triples(
+            (URIRef('urn:test:s'), URIRef('urn:test:p'), None)))
+        res2 = set(store.triples(
+            (URIRef('urn:test:s'), None, URIRef('urn:test:o'))))
+        res3 = set(store.triples(
+            (None, URIRef('urn:test:p'), URIRef('urn:test:o'))))
+        assert res1 == {(
+            URIRef('urn:test:s'), URIRef('urn:test:p'), URIRef('urn:test:o'))}
+        assert res2 == res1
+        assert res3 == res2
+
+
+    def test_triple_no_match(self, store):
+        '''
+        Test various mismatches.
+        '''
+        store.begin()
+        store.add((
+            URIRef('urn:test:s'),
+            URIRef('urn:test:p2'), URIRef('urn:test:o2')))
+        store.add((
+            URIRef('urn:test:s3'),
+            URIRef('urn:test:p3'), URIRef('urn:test:o3')))
+        store.commit()
+        res1 = set(store.triples((None, None, None)))
+        assert len(res1) == 3
+
+        res1 = set(store.triples(
+            (URIRef('urn:test:s2'), URIRef('urn:test:p'), None)))
+        res2 = set(store.triples(
+            (URIRef('urn:test:s3'), None, URIRef('urn:test:o'))))
+        res3 = set(store.triples(
+            (None, URIRef('urn:test:p3'), URIRef('urn:test:o2'))))
+
+        assert len(res1) == len(res2) == len(res3) == 0
+
+
+    def test_remove(self, store):
+        '''
+        Test removing one or more triples.
+        '''
+        store.begin()
+        store.remove((URIRef('urn:test:s3'),
+                URIRef('urn:test:p3'), URIRef('urn:test:o3')))
+        store.commit()
+
+        res1 = set(store.triples((None, None, None)))
+        assert len(res1) == 2
+
+        store.begin()
+        store.remove((URIRef('urn:test:s'), None, None))
+        store.commit()
+        res2 = set(store.triples((None, None, None)))
+        assert len(res2) == 0
+
+