From cd57e251c7cda38471c004de7a82bb6c0db1c6ae Mon Sep 17 00:00:00 2001
From: Bartek Wrona <wrona@syncad.com>
Date: Mon, 7 Sep 2020 23:12:47 +0200
Subject: [PATCH] Preliminary implementation of MT flush.
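
Post data, reputation, vote and tag caches now derive from a new
DbAdapterHolder class, which hands each of them a clone of the shared
database adapter plus simple transaction helpers (beginTx/commitTx).
Blocks.process_multi() flushes these caches through a ThreadPoolExecutor,
one worker per cache, with every flush committing on its own connection.
Follow also gets a cloned adapter but is still flushed sequentially,
together with Posts and the block data itself. The separate Reputations
processor instance (and its injection via set_reputations_processor) is
gone; all holders are classmethod-based singletons.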

---
 hive/indexer/blocks.py            | 69 ++++++++++++++++++++-----------
 hive/indexer/db_adapter_holder.py | 25 +++++++++++
 hive/indexer/follow.py            | 22 +++++-----
 hive/indexer/post_data_cache.py   | 14 +++++--
 hive/indexer/reputations.py       | 24 ++++++-----
 hive/indexer/sync.py              | 24 +++++------
 hive/indexer/tags.py              | 17 ++++----
 hive/indexer/votes.py             | 14 +++++--
 8 files changed, 136 insertions(+), 73 deletions(-)
 create mode 100644 hive/indexer/db_adapter_holder.py

diff --git a/hive/indexer/blocks.py b/hive/indexer/blocks.py
index 51ce55e61..3a1949096 100644
--- a/hive/indexer/blocks.py
+++ b/hive/indexer/blocks.py
@@ -2,6 +2,7 @@
 
 import logging
 import json
+import concurrent.futures
 
 from hive.db.adapter import Db
 
@@ -13,7 +14,7 @@ from hive.indexer.follow import Follow
 from hive.indexer.votes import Votes
 from hive.indexer.post_data_cache import PostDataCache
 from hive.indexer.tags import Tags
-
+from hive.indexer.reputations import Reputations
 
 from time import perf_counter
 
@@ -22,6 +23,8 @@ from hive.utils.stats import FlushStatusManager as FSM
 from hive.utils.trends import update_hot_and_tranding_for_block_range
 from hive.utils.post_active import update_active_starting_from_posts_on_block
 
+from concurrent.futures import ThreadPoolExecutor
+
 log = logging.getLogger(__name__)
 
 DB = Db.instance()
@@ -30,13 +33,16 @@ class Blocks:
     """Processes blocks, dispatches work, manages `hive_blocks` table."""
     blocks_to_flush = []
     _head_block_date = None
-    _reputations = None
     _current_block_date = None
 
-    def __init__(cls):
-        log.info("Creating a reputations processor")
-        log.info("Built blocks object: {}".format(cls))
+    # (description, flush callable, holder class) triples; these caches are
+    # mutually independent, so process_multi() can flush them in parallel
+    _concurrent_flush = [
+      ('PostDataCache', PostDataCache.flush, PostDataCache),
+      ('Reputations', Reputations.flush, Reputations),
+      ('Votes', Votes.flush, Votes),
+      ('Tags', Tags.flush, Tags),
+    ]
 
+    def __init__(cls):
         head_date = cls.head_date()
         if(head_date == ''):
             cls._head_block_date = None
@@ -45,11 +51,13 @@ class Blocks:
             cls._head_block_date = head_date
             cls._current_block_date = head_date
 
-    @classmethod 
-    def set_reputations_processor(cls, reputations_processor):
-        cls._reputations = reputations_processor
-        assert cls._reputations is not None, "Reputation object is None"
-        log.info("Built reputations object: {}".format(cls._reputations))
+    @classmethod
+    def setup_db_access(cls, sharedDbAdapter):
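+        # hand every cache holder its own cloned connection so their
+        # flushes can later run concurrently (see _concurrent_flush)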
+        PostDataCache.setup_db_access(sharedDbAdapter)
+        Reputations.setup_db_access(sharedDbAdapter)
+        Votes.setup_db_access(sharedDbAdapter)
+        Tags.setup_db_access(sharedDbAdapter)
+        Follow.setup_db_access(sharedDbAdapter)
 
     @classmethod
     def head_num(cls):
@@ -74,7 +82,7 @@ class Blocks:
         Tags.flush()
         Votes.flush()
         Posts.flush()
-        cls._reputations.flush()
+        Reputations.flush()
         block_num = int(block['block_id'][:8], base=16)
         cls.on_live_blocks_processed( block_num, block_num )
         time_end = perf_counter()
@@ -86,8 +94,6 @@ class Blocks:
         """Batch-process blocks; wrapped in a transaction."""
         time_start = OPSM.start()
 
-        log.info("Blocks object: {}".format(cls))
-
         DB.query("START TRANSACTION")
 
         last_num = 0
@@ -113,13 +119,34 @@ class Blocks:
 
         log.info("#############################################################################")
         flush_time = register_time(flush_time, "Blocks", cls._flush_blocks())
-        flush_time = register_time(flush_time, "PostDataCache", PostDataCache.flush())
-        flush_time = register_time(flush_time, "Tags", Tags.flush())
-        flush_time = register_time(flush_time, "Votes", Votes.flush())
+        flush_time = register_time(flush_time, "Posts", Posts.flush())
+
+        # PostDataCache, Tags, Votes and Reputations are flushed
+        # concurrently below instead of sequentially here.
         folllow_items = len(Follow.follow_items_to_flush) + Follow.flush(trx=False)
         flush_time = register_time(flush_time, "Follow", folllow_items)
-        flush_time = register_time(flush_time, "Posts", Posts.flush())
-        flush_time = register_time(flush_time, "Reputations", cls._flush_reputations())
+
+        completedThreads = 0
+
+        # flush the independent caches in parallel; every holder owns its own
+        # cloned DB connection, so their transactions cannot interfere
+        with ThreadPoolExecutor(max_workers = len(cls._concurrent_flush)) as pool:
+            flush_futures = {pool.submit(f): (description, c) for (description, f, c) in cls._concurrent_flush}
+            for future in concurrent.futures.as_completed(flush_futures):
+                (description, c) = flush_futures[future]
+                completedThreads = completedThreads + 1
+                try:
+                    n = future.result()
+                    # every flush has to commit its own transaction before returning
+                    assert not c.tx_active()
+
+                    if n > 0:
+                        log.info('%r flush generated %d records', description, n)
+                except Exception as exc:
+                    log.error('%r generated an exception: %s', description, exc)
+                    raise
+
+        assert completedThreads == len(cls._concurrent_flush)
 
         if (not is_initial_sync) and (first_block > -1):
             cls.on_live_blocks_processed( first_block, last_num )
@@ -162,7 +189,7 @@ class Blocks:
             elif op_type == 'effective_comment_vote_operation':
                 key_vote = "{}/{}/{}".format(op_value['voter'], op_value['author'], op_value['permlink'])
 
-                cls._reputations.process_vote(block_num, op_value)
+                Reputations.process_vote(block_num, op_value)
 
                 vote_ops[ key_vote ] = op_value
 
@@ -339,10 +366,6 @@ class Blocks:
             'date': block['timestamp']})
         return num
 
-    @classmethod
-    def _flush_reputations(cls):
-        return cls._reputations.flush()
-
     @classmethod
     def _flush_blocks(cls):
         query = """
diff --git a/hive/indexer/db_adapter_holder.py b/hive/indexer/db_adapter_holder.py
new file mode 100644
index 000000000..db720051e
--- /dev/null
+++ b/hive/indexer/db_adapter_holder.py
@@ -0,0 +1,25 @@
+import logging
+log = logging.getLogger(__name__)
+
+class DbAdapterHolder(object):
+    """Base for cache classes that flush on their own DB connection.
+
+    setup_db_access() clones the shared adapter, so each derived class
+    can run its flush transaction independently of the others.
+    """
+    db = None
+
+    _inside_tx = False
+
+    @classmethod
+    def setup_db_access(cls, sharedDb):
+        cls.db = sharedDb.clone()
+
+    @classmethod
+    def tx_active(cls):
+        return cls._inside_tx
+
+    @classmethod
+    def beginTx(cls):
+        cls.db.query("START TRANSACTION")
+        cls._inside_tx = True
+
+    @classmethod
+    def commitTx(cls):
+        cls.db.query("COMMIT")
+        cls._inside_tx = False
diff --git a/hive/indexer/follow.py b/hive/indexer/follow.py
index 92deadf60..c3485d106 100644
--- a/hive/indexer/follow.py
+++ b/hive/indexer/follow.py
@@ -9,9 +9,9 @@ from hive.db.db_state import DbState
 from hive.indexer.accounts import Accounts
 from hive.indexer.notify import Notify
 
-log = logging.getLogger(__name__)
+from hive.indexer.db_adapter_holder import DbAdapterHolder
 
-DB = Db.instance()
+log = logging.getLogger(__name__)
 
 FOLLOWERS = 'followers'
 FOLLOWING = 'following'
@@ -65,7 +65,7 @@ def _flip_dict(dict_to_flip):
             flipped[value] = [key]
     return flipped
 
-class Follow:
+class Follow(DbAdapterHolder):
     """Handles processing of incoming follow ups and flushing to db."""
 
     follow_items_to_flush = dict()
@@ -99,7 +99,7 @@ class Follow:
         else:
             old_state = cls._get_follow_db_state(op['flr'], op['flg'])
             # insert or update state
-            DB.query(FOLLOW_ITEM_INSERT_QUERY, **op)
+            cls.db.query(FOLLOW_ITEM_INSERT_QUERY, **op)
             if new_state == 1:
                 Follow.follow(op['flr'], op['flg'])
                 if old_state is None:
@@ -142,7 +142,7 @@ class Follow:
         sql = """SELECT state FROM hive_follows
                   WHERE follower = :follower
                     AND following = :following"""
-        return DB.query_one(sql, follower=follower, following=following)
+        return cls.db.query_one(sql, follower=follower, following=following)
 
 
     # -- stat tracking --
@@ -206,7 +206,7 @@ class Follow:
             else:
                 query = sql_prefix + ",".join(values)
                 query += sql_postfix
-                DB.query(query)
+                cls.db.query(query)
                 values.clear()
                 values.append("({}, {}, '{}', {}, {}, {})".format(follow_item['flr'], follow_item['flg'],
                                                                   follow_item['at'], follow_item['state'],
@@ -217,7 +217,7 @@ class Follow:
         if len(values) > 0:
             query = sql_prefix + ",".join(values)
             query += sql_postfix
-            DB.query(query)
+            cls.db.query(query)
 
         cls.follow_items_to_flush.clear()
 
@@ -239,7 +239,7 @@ class Follow:
             return 0
 
         start = perf()
-        DB.batch_queries(sqls, trx=trx)
+        cls.db.batch_queries(sqls, trx=trx)
         if trx:
             log.info("[SYNC] flushed %d follow deltas in %ds",
                      updated, perf() - start)
@@ -263,7 +263,7 @@ class Follow:
                    following = (SELECT COUNT(*) FROM hive_follows WHERE state = 1 AND follower  = hive_accounts.id)
              WHERE id IN :ids
         """
-        DB.query(sql, ids=tuple(ids))
+        cls.db.query(sql, ids=tuple(ids))
 
     @classmethod
     def force_recount(cls):
@@ -281,7 +281,7 @@ class Follow:
                LEFT JOIN hive_follows hf ON id = hf.following AND state = 1
                 GROUP BY id);
         """
-        DB.query(sql)
+        cls.db.query(sql)
 
         log.info("[SYNC] update follower counts")
         sql = """
@@ -291,4 +291,4 @@ class Follow:
             UPDATE hive_accounts SET following = num FROM following_counts
              WHERE id = account_id AND following != num;
         """
-        DB.query(sql)
+        cls.db.query(sql)
diff --git a/hive/indexer/post_data_cache.py b/hive/indexer/post_data_cache.py
index 5cdcc3a4b..7b9f3f941 100644
--- a/hive/indexer/post_data_cache.py
+++ b/hive/indexer/post_data_cache.py
@@ -1,14 +1,17 @@
 import logging
 from hive.utils.normalize import escape_characters
-from hive.db.adapter import Db
 
+from hive.indexer.db_adapter_holder import DbAdapterHolder
+
 log = logging.getLogger(__name__)
-DB = Db.instance()
 
-class PostDataCache(object):
+class PostDataCache(DbAdapterHolder):
     """ Procides cache for DB operations on post data table in order to speed up initial sync """
     _data = {}
 
     @classmethod
     def is_cached(cls, pid):
         """ Check if data is cached """
@@ -28,7 +31,7 @@ class PostDataCache(object):
             sql = """
                   SELECT hpd.body FROM hive_post_data hpd WHERE hpd.id = :post_id;
                   """
-            row = DB.query_row(sql, post_id = pid)
+            row = cls.db.query_row(sql, post_id = pid)
             post_data = dict(row)
         return post_data['body']
 
@@ -36,6 +39,7 @@ class PostDataCache(object):
     def flush(cls, print_query = False):
         """ Flush data from cache to db """
         if cls._data:
+            cls.beginTx()
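+            # build one multi-row INSERT from the cache and commit it in a single transaction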
             sql = """
                 INSERT INTO 
                     hive_post_data (id, title, preview, img_url, body, json) 
@@ -66,7 +70,9 @@ class PostDataCache(object):
             if(print_query):
                 log.info("Executing query:\n{}".format(sql))
 
-            DB.query(sql)
+            cls.db.query(sql)
+            cls.commitTx()
+
         n = len(cls._data.keys())
         cls._data.clear()
         return n
diff --git a/hive/indexer/reputations.py b/hive/indexer/reputations.py
index b35680159..127c9fce1 100644
--- a/hive/indexer/reputations.py
+++ b/hive/indexer/reputations.py
@@ -1,25 +1,27 @@
 """ Reputation update support """
 
 import logging
+from hive.indexer.db_adapter_holder import DbAdapterHolder
+
 log = logging.getLogger(__name__)
 
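+# number of pending process_reputation_data() calls batched into a single query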
 CACHED_ITEMS_LIMIT = 200
 
-class Reputations:
+class Reputations(DbAdapterHolder):
     _queries = []
-    _db = None
-
-    def __init__(self, database):
-        log.info("Cloning database...")
-        self._db = database.clone()
-        assert self._db is not None, "Database not cloned"
-        log.info("Database object at: {}".format(self._db))
 
+    @classmethod
     def process_vote(self, block_num, effective_vote_op):
         self._queries.append("\nSELECT process_reputation_data({}, '{}', '{}', '{}', {});".format(block_num, effective_vote_op['author'], effective_vote_op['permlink'],
              effective_vote_op['voter'], effective_vote_op['rshares']))
 
+    @classmethod
     def flush(self):
+        if not self._queries:
+            return 0
+
+        self.beginTx()
+
         query = ""
         i = 0
         items = 0
@@ -28,16 +30,18 @@ class Reputations:
             i = i + 1
             items = items + 1
             if items >= CACHED_ITEMS_LIMIT:
-                self._db.query_no_return(query)
+                self.db.query_no_return(query)
                 query = ""
                 items = 0
 
-        if items >= CACHED_ITEMS_LIMIT:
-            self._db.query_no_return(query)
+        # send whatever remained below CACHED_ITEMS_LIMIT
+        if items > 0:
+            self.db.query_no_return(query)
             query = ""
             items = 0
 
         n = len(self._queries)
         self._queries.clear()
+
+        self.commitTx()
         return n
 
diff --git a/hive/indexer/sync.py b/hive/indexer/sync.py
index e91f64cf6..599ca2c1a 100644
--- a/hive/indexer/sync.py
+++ b/hive/indexer/sync.py
@@ -91,7 +91,7 @@ def _vops_provider(node, queue, lbound, ubound, chunk_size):
     except Exception:
         log.exception("Exception caught during fetching vops...")
 
-def _block_consumer(node, blocksProcessor, blocksQueue, vopsQueue, is_initial_sync, lbound, ubound, chunk_size):
+def _block_consumer(node, blocksQueue, vopsQueue, is_initial_sync, lbound, ubound, chunk_size):
     from hive.utils.stats import minmax
     is_debug = log.isEnabledFor(10)
     num = 0
@@ -128,7 +128,7 @@ def _block_consumer(node, blocksProcessor, blocksQueue, vopsQueue, is_initial_sy
             timer.batch_start()
             
             block_start = perf()
-            blocksProcessor.process_multi(blocks, preparedVops, node, is_initial_sync)
+            Blocks.process_multi(blocks, preparedVops, node, is_initial_sync)
             block_end = perf()
 
             timer.batch_lap()
@@ -184,7 +184,7 @@ def _node_data_provider(self, is_initial_sync, lbound, ubound, chunk_size):
         try:
             pool.submit(_block_provider, self._steem, blocksQueue, lbound, ubound, chunk_size)
             pool.submit(_vops_provider, self._steem, vopsQueue, lbound, ubound, chunk_size)
-            blockConsumerFuture = pool.submit(_block_consumer, self._steem, self._blocksProcessor, blocksQueue, vopsQueue, is_initial_sync, lbound, ubound, chunk_size)
+            blockConsumerFuture = pool.submit(_block_consumer, self._steem, blocksQueue, vopsQueue, is_initial_sync, lbound, ubound, chunk_size)
 
             blockConsumerFuture.result()
             if not CONTINUE_PROCESSING and blocksQueue.empty() and vopsQueue.empty():
@@ -211,15 +211,13 @@ class Sync:
         log.info("Using hived url: `%s'", self._conf.get('steemd_url'))
 
         self._steem = conf.steem()
-        self._blocksProcessor = None
 
     def run(self):
         """Initialize state; setup/recovery checks; sync and runloop."""
 
         # ensure db schema up to date, check app status
         DbState.initialize()
-        Blocks.set_reputations_processor(Reputations(self._db))
-        self._blocksProcessor = Blocks()
+        Blocks.setup_db_access(self._db)
 
         # prefetch id->name and id->rank memory maps
         Accounts.load_ids()
@@ -235,16 +233,16 @@ class Sync:
         Community.recalc_pending_payouts()
 
         if DbState.is_initial_sync():
-            last_imported_block = self._blocksProcessor.head_num()
+            last_imported_block = Blocks.head_num()
             # resume initial sync
             self.initial()
             if not CONTINUE_PROCESSING:
                 return
-            current_imported_block = self._blocksProcessor.head_num()
+            current_imported_block = Blocks.head_num()
             DbState.finish_initial_sync(current_imported_block, last_imported_block)
         else:
             # recover from fork
-            self._blocksProcessor.verify_head(self._steem)
+            Blocks.verify_head(self._steem)
 
         self._update_chain_state()
 
@@ -319,7 +317,7 @@ class Sync:
         """Fast sync strategy: read/process blocks in batches."""
         steemd = self._steem
 
-        lbound = self._blocksProcessor.head_num() + 1
+        lbound = Blocks.head_num() + 1
         ubound = self._conf.get('test_max_block') or steemd.last_irreversible()
 
         count = ubound - lbound
@@ -344,7 +342,7 @@ class Sync:
             timer.batch_lap()
 
             # process blocks
-            self._blocksProcessor.process_multi(blocks, preparedVops, steemd, is_initial_sync)
+            Blocks.process_multi(blocks, preparedVops, steemd, is_initial_sync)
             timer.batch_finish(len(blocks))
 
             _prefix = ("[SYNC] Got block %d @ %s" % (
@@ -366,13 +364,13 @@ class Sync:
 
-        assert self._blocksProcessor 
         steemd = self._steem
-        hive_head = self._blocksProcessor.head_num()
+        hive_head = Blocks.head_num()
 
         for block in steemd.stream_blocks(hive_head + 1, trail_blocks, max_gap):
             start_time = perf()
 
             self._db.query("START TRANSACTION")
-            num = self._blocksProcessor.process(block, {}, steemd)
+            num = Blocks.process(block, {}, steemd)
             follows = Follow.flush(trx=False)
             accts = Accounts.flush(steemd, trx=False, spread=8)
             self._db.query("COMMIT")
diff --git a/hive/indexer/tags.py b/hive/indexer/tags.py
index 44d942395..1179c3c50 100644
--- a/hive/indexer/tags.py
+++ b/hive/indexer/tags.py
@@ -1,12 +1,12 @@
 import logging
-from hive.db.adapter import Db
+from hive.indexer.db_adapter_holder import DbAdapterHolder
 
 log = logging.getLogger(__name__)
-DB = Db.instance()
 
 from hive.utils.normalize import escape_characters
 
-class Tags(object):
+class Tags(DbAdapterHolder):
     """ Tags cache """
     _tags = []
 
@@ -17,8 +17,9 @@ class Tags(object):
 
     @classmethod
     def flush(cls):
-        """ Flush tags to table """
+        """ Flush tags to table """        
         if cls._tags:
+            cls.beginTx()
             limit = 1000
 
             sql = """
@@ -32,11 +33,11 @@ class Tags(object):
                 values.append("({})".format(escape_characters(tag[1])))
                 if len(values) >= limit:
                     tag_query = str(sql)
-                    DB.query(tag_query.format(','.join(values)))
+                    cls.db.query(tag_query.format(','.join(values)))
                     values.clear()
             if len(values) > 0:
                 tag_query = str(sql)
-                DB.query(tag_query.format(','.join(values)))
+                cls.db.query(tag_query.format(','.join(values)))
                 values.clear()
 
             sql = """
@@ -62,13 +63,13 @@ class Tags(object):
                 values.append("({}, {})".format(tag[0], escape_characters(tag[1])))
                 if len(values) >= limit:
                     tag_query = str(sql)
-                    DB.query(tag_query.format(','.join(values)))
+                    cls.db.query(tag_query.format(','.join(values)))
                     values.clear()
             if len(values) > 0:
                 tag_query = str(sql)
-                DB.query(tag_query.format(','.join(values)))
+                cls.db.query(tag_query.format(','.join(values)))
                 values.clear()
-            
+            cls.commitTx()
         n = len(cls._tags)
         cls._tags.clear()
         return n
diff --git a/hive/indexer/votes.py b/hive/indexer/votes.py
index a2d82e99e..6318bc2a5 100644
--- a/hive/indexer/votes.py
+++ b/hive/indexer/votes.py
@@ -4,11 +4,11 @@ import logging
 
 from hive.db.db_state import DbState
-from hive.db.adapter import Db
+from hive.indexer.db_adapter_holder import DbAdapterHolder
 
 log = logging.getLogger(__name__)
-DB = Db.instance()
 
-class Votes:
+class Votes(DbAdapterHolder):
     """ Class for managing posts votes """
     _votes_data = {}
 
@@ -62,9 +62,12 @@ class Votes:
     @classmethod
     def flush(cls):
         """ Flush vote data from cache to database """
+
         cls.inside_flush = True
         n = 0
         if cls._votes_data:
+            cls.beginTx()
+
             sql = """
                 INSERT INTO hive_votes
                 (post_id, voter_id, author_id, permlink_id, weight, rshares, vote_percent, last_update, block_num, is_effective)
@@ -105,16 +108,19 @@ class Votes:
                 if len(values) >= values_limit:
                     values_str = ','.join(values)
                     actual_query = sql.format(values_str)
-                    DB.query(actual_query)
+                    cls.db.query(actual_query)
                     values.clear()
 
             if len(values) > 0:
                 values_str = ','.join(values)
                 actual_query = sql.format(values_str)
-                DB.query(actual_query)
+                cls.db.query(actual_query)
                 values.clear()
 
             n = len(cls._votes_data)
             cls._votes_data.clear()
+            cls.commitTx()
+
         cls.inside_flush = False
+
         return n
-- 
GitLab