From 96dca79e3c37403c86869ca9736308b315abcfbf Mon Sep 17 00:00:00 2001
From: Inso <insomniak.fr@gmail.com>
Date: Wed, 7 Oct 2015 08:53:25 +0200
Subject: [PATCH] Handle rollbacks

---
 src/cutecoin/core/account.py   | 39 +++++++++++++++++
 src/cutecoin/core/app.py       |  5 ++-
 src/cutecoin/core/transfer.py  | 10 ++---
 src/cutecoin/core/txhistory.py | 80 ++++++++++++++++++++++------------
 src/cutecoin/core/wallet.py    |  9 ++++
 5 files changed, 110 insertions(+), 33 deletions(-)

diff --git a/src/cutecoin/core/account.py b/src/cutecoin/core/account.py
index 5e6a66f8..2a5529be 100644
--- a/src/cutecoin/core/account.py
+++ b/src/cutecoin/core/account.py
@@ -186,6 +186,45 @@ class Account(QObject):
             w.init_cache(app, community)
             w.refresh_transactions(community, received_list)
 
+    def rollback_transaction(self, app, community):
+        """
+        Refresh the local account cache
+        This needs n_wallets * n_communities cache refreshing to end
+
+        .. note:: emit the Account pyqtSignal loading_progressed during refresh
+        """
+        logging.debug("Start refresh transactions")
+        loaded_wallets = 0
+        received_list = []
+        values = {}
+        maximums = {}
+
+        def progressing(value, maximum, hash):
+            #logging.debug("Loading = {0} : {1} : {2}".format(value, maximum, loaded_wallets))
+            values[hash] = value
+            maximums[hash] = maximum
+            account_value = sum(values.values())
+            account_max = sum(maximums.values())
+            self.loading_progressed.emit(community, account_value, account_max)
+
+        def wallet_finished(received):
+            logging.debug("Finished loading wallet")
+            nonlocal loaded_wallets
+            loaded_wallets += 1
+            if loaded_wallets == len(self.wallets):
+                logging.debug("All wallets loaded")
+                self._refreshing = False
+                self.loading_finished.emit(community, received_list)
+                for w in self.wallets:
+                    w.refresh_progressed.disconnect(progressing)
+                    w.refresh_finished.disconnect(wallet_finished)
+
+        for w in self.wallets:
+            w.refresh_progressed.connect(progressing)
+            w.refresh_finished.connect(wallet_finished)
+            w.init_cache(app, community)
+            w.rollback_transactions(community, received_list)
+
     def set_display_referential(self, index):
         self._current_ref = index
 
diff --git a/src/cutecoin/core/app.py b/src/cutecoin/core/app.py
index 3661f322..cc26dd67 100644
--- a/src/cutecoin/core/app.py
+++ b/src/cutecoin/core/app.py
@@ -242,7 +242,10 @@ class Application(QObject):
                 def refresh_tx(blocknumber, co=community):
                     account.refresh_transactions(self, co)
                 community.network.new_block_mined.connect(refresh_tx)
-                account.refresh_transactions(self, community)
+
+                def rollback_tx(blocknumber, co=community):
+                    account.rollback_transaction(self, co)
+                community.network.new_block_mined.connect(rollback_tx)
 
     def load_cache(self, account):
         """
diff --git a/src/cutecoin/core/transfer.py b/src/cutecoin/core/transfer.py
index 6444fe62..028e3082 100644
--- a/src/cutecoin/core/transfer.py
+++ b/src/cutecoin/core/transfer.py
@@ -163,7 +163,7 @@ class Transfer(QObject):
         """
         if not rollback:
             for tx in block.transactions:
-                if tx.hash == self.sha_hash:
+                if tx.sha_hash == self.sha_hash:
                     return False
             if block.time > self.metadata['time'] + mediantime_target*mediantime_blocks:
                 return True
@@ -179,7 +179,7 @@ class Transfer(QObject):
         """
         if not rollback:
             for tx in block.transactions:
-                if tx.hash == self.sha_hash:
+                if tx.sha_hash == self.sha_hash:
                     return True
         return False
 
@@ -224,7 +224,7 @@ class Transfer(QObject):
             if not block or block.blockid != self.blockid:
                 return True
             else:
-                return self.sha_hash not in [t.hash for t in block.transactions]
+                return self.sha_hash not in [t.sha_hash for t in block.transactions]
         return False
 
     def _rollback_still_present(self, rollback, block):
@@ -235,7 +235,7 @@ class Transfer(QObject):
         :return: True if the transfer is found in the block
         """
         if rollback and block.blockid == self.blockid:
-            return self.sha_hash in [t.hash for t in block.transactions]
+            return self.sha_hash in [t.sha_hash for t in block.transactions]
         return False
 
     def _rollback_and_local(self, rollback, block):
@@ -246,7 +246,7 @@ class Transfer(QObject):
         :return: True if the transfer is found in the block
         """
         if rollback and self._locally_created and block.blockid == self.blockid:
-            return self.sha_hash not in [t.hash for t in block.transactions]
+            return self.sha_hash not in [t.sha_hash for t in block.transactions]
         return False
 
     def _is_locally_created(self):
diff --git a/src/cutecoin/core/txhistory.py b/src/cutecoin/core/txhistory.py
index a6d358f1..61bce67a 100644
--- a/src/cutecoin/core/txhistory.py
+++ b/src/cutecoin/core/txhistory.py
@@ -15,7 +15,6 @@ class TxHistory():
         self.app = app
         self._stop_coroutines = False
         self._running_refresh = []
-        self._block_to = None
         self._transfers = []
         self.available_sources = []
         self._dividends = []
@@ -224,7 +223,7 @@ class TxHistory():
         return {}
 
     @asyncio.coroutine
-    def _refresh(self, community, block_number_from, received_list):
+    def _refresh(self, community, block_number_from, block_to, received_list):
         """
         Refresh last transactions
 
@@ -234,16 +233,16 @@ class TxHistory():
         new_transfers = []
         new_dividends = []
         try:
-            logging.debug("Refresh from : {0} to {1}".format(block_number_from, self._block_to['number']))
+            logging.debug("Refresh from : {0} to {1}".format(block_number_from, block_to['number']))
             dividends = yield from self.request_dividends(community, block_number_from)
             with_tx_data = yield from community.bma_access.future_request(bma.blockchain.TX)
             members_pubkeys = yield from community.members_pubkeys()
             fork_window = community.network.fork_window(members_pubkeys)
             blocks_with_tx = with_tx_data['result']['blocks']
-            while block_number_from <= self._block_to['number']:
+            while block_number_from <= block_to['number']:
                 udid = 0
                 for d in [ud for ud in dividends if ud['block_number'] == block_number_from]:
-                    state = TransferState.VALIDATED if block_number_from + fork_window <= self._block_to['number'] \
+                    state = TransferState.VALIDATED if block_number_from + fork_window <= block_to['number'] \
                         else TransferState.VALIDATING
 
                     if d['block_number'] not in [ud['block_number'] for ud in self._dividends]:
@@ -260,15 +259,15 @@ class TxHistory():
                 # We parse only blocks with transactions
                 if block_number_from in blocks_with_tx:
                     transfers = yield from self._parse_block(community, block_number_from,
-                                                             received_list, self._block_to,
+                                                             received_list, block_to,
                                                              udid + len(new_transfers))
                     new_transfers += transfers
 
-                self.wallet.refresh_progressed.emit(block_number_from, self._block_to['number'], self.wallet.pubkey)
+                self.wallet.refresh_progressed.emit(block_number_from, block_to['number'], self.wallet.pubkey)
                 block_number_from += 1
 
-            signed_raw = "{0}{1}\n".format(self._block_to['raw'],
-                                       self._block_to['signature'])
+            signed_raw = "{0}{1}\n".format(block_to['raw'],
+                                       block_to['signature'])
             block_to = Block.from_signed_raw(signed_raw)
             for transfer in [t for t in self._transfers + new_transfers if t.state == TransferState.VALIDATING]:
                 transfer.run_state_transitions((False, block_to, fork_window))
@@ -282,7 +281,7 @@ class TxHistory():
 
             parameters = yield from community.parameters()
             for transfer in [t for t in self._transfers if t.state == TransferState.AWAITING]:
-                transfer.run_state_transitions((False, self._block_to,
+                transfer.run_state_transitions((False, block_to,
                                                 parameters['avgGenTime'], parameters['medianTimeBlocks']))
         except NoPeerAvailable as e:
             logging.debug(str(e))
@@ -320,10 +319,13 @@ class TxHistory():
                 if '404' in str(e):
                     block = None
                     tries += 1
-        for transfer in [t for t in self._transfers
-                         if t.state in (TransferState.VALIDATING, TransferState.VALIDATED) and
-                         t.blockid.number == block_number]:
-            return not transfer.run_state_transitions((True, block_doc))
+        if block_doc:
+            for transfer in [t for t in self._transfers
+                             if t.state in (TransferState.VALIDATING, TransferState.VALIDATED) and
+                             t.blockid.number == block_number]:
+                return not transfer.run_state_transitions((True, block_doc))
+        else:
+            return False
 
     @asyncio.coroutine
     def _rollback(self, community):
@@ -334,13 +336,14 @@ class TxHistory():
         :param cutecoin.core.Community community: The community
         """
         try:
-            logging.debug("Rollback from : {0}".format(self._block_to['number']))
+            logging.debug("Rollback from : {0}".format(self.latest_block))
             # We look for the block goal to check for rollback,
             #  depending on validating and validated transfers...
             tx_blocks = [tx.blockid.number for tx in self._transfers
-                      if tx.state in (TransferState.VALIDATED, TransferState.VALIDATING) \
-                     and tx.blockid is not None]
-            for block_number in tx_blocks:
+                          if tx.state in (TransferState.VALIDATED, TransferState.VALIDATING) and
+                          tx.blockid is not None]
+            for i, block_number in enumerate(tx_blocks):
+                self.wallet.refresh_progressed.emit(i, len(tx_blocks), self.wallet.pubkey)
                 if (yield from self._check_block(community, block_number)):
                     return
         except NoPeerAvailable:
@@ -363,20 +366,43 @@ class TxHistory():
                           if ud['state'] in (TransferState.AWAITING, TransferState.VALIDATING)]
                 blocks = tx_blocks + ud_blocks + \
                          [max(0, self.latest_block - community.network.fork_window(members_pubkeys))]
-                parsed_block = min(set(blocks))
-                self._block_to = current_block
+                block_from = min(set(blocks))
 
-                # We wait for current refresh coroutines
-                if len(self._running_refresh) > 0:
-                    logging.debug("Wait for the end of previous refresh")
-                    done, pending = yield from asyncio.wait(self._running_refresh)
-                    for cor in done:
-                        self._running_refresh.remove(cor)
+                yield from self._wait_for_previous_refresh()
 
                 # Then we start a new one
-                task = asyncio.async(self._refresh(community, parsed_block, received_list))
+                logging.debug("Starts a new refresh")
+                task = asyncio.async(self._refresh(community, block_from, current_block, received_list))
                 self._running_refresh.append(task)
         except ValueError as e:
             logging.debug("Block not found")
         except NoPeerAvailable:
             logging.debug("No peer available")
+
+    @asyncio.coroutine
+    def rollback(self, community, received_list):
+        yield from self._wait_for_previous_refresh()
+        # Then we start a new one
+        logging.debug("Starts a new refresh")
+        task = asyncio.async(self._rollback(community))
+        self._running_refresh.append(task)
+
+        # Then we start a refresh to check for new transactions
+        yield from self.refresh(community, received_list)
+
+    @asyncio.coroutine
+    def _wait_for_previous_refresh(self):
+        # We wait for current refresh coroutines
+        if len(self._running_refresh) > 0:
+            logging.debug("Wait for the end of previous refresh")
+            done, pending = yield from asyncio.wait(self._running_refresh)
+            for cor in done:
+                try:
+                    self._running_refresh.remove(cor)
+                except ValueError:
+                    logging.debug("Task already removed.")
+            for p in pending:
+                logging.debug("Still waiting for : {0}".format(p))
+            logging.debug("Previous refresh finished")
+        else:
+            logging.debug("No previous refresh")
diff --git a/src/cutecoin/core/wallet.py b/src/cutecoin/core/wallet.py
index 2b4342de..c2e906ed 100644
--- a/src/cutecoin/core/wallet.py
+++ b/src/cutecoin/core/wallet.py
@@ -112,6 +112,15 @@ class Wallet(QObject):
         logging.debug("Refresh transactions for {0}".format(self.pubkey))
         asyncio.async(self.caches[community.currency].refresh(community, received_list))
 
+    def rollback_transactions(self, community, received_list):
+        """
+        Rollback the transactions of this wallet for the specified community.
+
+        :param community: The community to refresh its cache
+        """
+        logging.debug("Refresh transactions for {0}".format(self.pubkey))
+        asyncio.async(self.caches[community.currency].rollback(community, received_list))
+
     def check_password(self, salt, password):
         """
         Check if wallet password is ok.
-- 
GitLab