From 9c57fb7c8b3ea8c58ae3c4cb99631976271d16ef Mon Sep 17 00:00:00 2001
From: Nicolas80 <nicolas.pmail@protonmail.com>
Date: Wed, 15 Jan 2025 15:44:38 +0100
Subject: [PATCH] Small cleanup for data's db connection; adding method
 connect_db that returns a &DatabaseConnection (or panics with explanation
 message)

Also changed some methods visibility so only connect_db is visible for usage.
---
 src/commands/vault.rs | 70 ++++++++++++++++---------------------------
 src/conf.rs           |  4 +--
 src/data.rs           | 10 +++++--
 3 files changed, 35 insertions(+), 49 deletions(-)

diff --git a/src/commands/vault.rs b/src/commands/vault.rs
index 7aeda3e..adf9245 100644
--- a/src/commands/vault.rs
+++ b/src/commands/vault.rs
@@ -117,35 +117,23 @@ fn decrypt(input: &[u8], passphrase: String) -> Result<Vec<u8>, age::DecryptErro
 
 /// handle vault commands
 pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliError> {
+	let db = data.connect_db();
+
 	// match subcommand
 	match command {
 		Subcommand::List(choice) => match choice {
 			ListChoice::All => {
-				let derivations = vault_derivation::list_all_derivations_in_order(
-					data.connection.as_ref().unwrap(),
-				)
-				.await?;
+				let derivations = vault_derivation::list_all_derivations_in_order(db).await?;
 
-				let table = compute_vault_derivations_table(
-					data.connection.as_ref().unwrap(),
-					&derivations,
-				)
-				.await?;
+				let table = compute_vault_derivations_table(db, &derivations).await?;
 
 				println!("available SS58 Addresses:");
 				println!("{table}");
 			}
 			ListChoice::Account => {
-				let derivations = vault_derivation::list_all_root_derivations_in_order(
-					data.connection.as_ref().unwrap(),
-				)
-				.await?;
+				let derivations = vault_derivation::list_all_root_derivations_in_order(db).await?;
 
-				let table = compute_vault_derivations_table(
-					data.connection.as_ref().unwrap(),
-					&derivations,
-				)
-				.await?;
+				let table = compute_vault_derivations_table(db, &derivations).await?;
 
 				println!("available <Account> SS58 Addresses:");
 				println!("{table}");
@@ -157,16 +145,12 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 					retrieve_vault_derivation(&data, address_or_vault_name).await?;
 
 				let linked_derivations = vault_derivation::fetch_all_linked_derivations_in_order(
-					data.connection.as_ref().unwrap(),
+					db,
 					&selected_derivation.root_address,
 				)
 				.await?;
 
-				let table = compute_vault_derivations_table(
-					data.connection.as_ref().unwrap(),
-					&linked_derivations,
-				)
-				.await?;
+				let table = compute_vault_derivations_table(db, &linked_derivations).await?;
 
 				println!("available SS58 Addresses linked to {selected_derivation}:");
 				println!("{table}");
@@ -219,7 +203,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			println!("Trying to import for SS58 address :'{}'", address_to_import);
 
 			if let Some(derivation) = vault_derivation::Entity::find_by_id(&address_to_import)
-				.one(data.connection.as_ref().unwrap())
+				.one(db)
 				.await?
 			{
 				println!(
@@ -228,17 +212,13 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 				);
 
 				let linked_derivations = vault_derivation::fetch_all_linked_derivations_in_order(
-					data.connection.as_ref().unwrap(),
+					db,
 					&derivation.root_address.clone(),
 				)
 				.await?;
 				println!("Here are all the SS58 Addresses linked to it in the vault:");
 
-				let table = compute_vault_derivations_table(
-					data.connection.as_ref().unwrap(),
-					&linked_derivations,
-				)
-				.await?;
+				let table = compute_vault_derivations_table(db, &linked_derivations).await?;
 				println!("{table}");
 
 				return Ok(());
@@ -250,7 +230,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			println!("(Optional) Enter a name for the vault entry");
 			let name = inputs::prompt_vault_name()?;
 
-			let txn = data.connection.as_ref().unwrap().begin().await?;
+			let txn = db.begin().await?;
 
 			let _derivation = create_derivation_for_vault_data_to_import(
 				&txn,
@@ -280,7 +260,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			}
 
 			let vault_account = vault_account::Entity::find_by_id(&root_derivation.address)
-				.one(data.connection.as_ref().unwrap())
+				.one(db)
 				.await?
 				.ok_or(GcliError::Input(format!(
 					"Could not find vault_account for address:'{}'",
@@ -315,7 +295,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			let derivation_address: String = derivation_keypair.address().to_string();
 
 			let check_derivation = vault_derivation::Entity::find_by_id(&derivation_address)
-				.one(data.connection.as_ref().unwrap())
+				.one(db)
 				.await?;
 
 			if check_derivation.is_some() {
@@ -332,12 +312,12 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 				path: Set(Some(derivation_path)),
 				root_address: Set(root_derivation.root_address.clone()),
 			};
-			let derivation = derivation.insert(data.connection.as_ref().unwrap()).await?;
+			let derivation = derivation.insert(db).await?;
 			println!("Created: {}", derivation);
 		}
 		Subcommand::Rename { address } => {
 			let derivation = vault_derivation::Entity::find_by_id(address.to_string())
-				.one(data.connection.as_ref().unwrap())
+				.one(db)
 				.await?;
 
 			if derivation.is_none() {
@@ -359,7 +339,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			let old_name = derivation.name.clone();
 			let mut derivation: vault_derivation::ActiveModel = derivation.into();
 			derivation.name = Set(name.clone());
-			let _derivation = derivation.update(data.connection.as_ref().unwrap()).await?;
+			let _derivation = derivation.update(db).await?;
 			println!(
 				"Renamed address:'{address}' from {:?} to {:?}",
 				old_name, name
@@ -371,7 +351,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			let derivation = retrieve_vault_derivation(&data, address_or_vault_name).await?;
 			let address_to_delete = derivation.address.clone();
 
-			let txn = data.connection.as_ref().unwrap().begin().await?;
+			let txn = db.begin().await?;
 
 			//If deleting a root derivation; also delete the vault account and all linked derivations
 			if derivation.path.is_none() {
@@ -430,7 +410,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 			for address in vault_key_addresses {
 				//Check if we already have a vault_derivation for that address
 				let derivation = vault_derivation::Entity::find_by_id(&address)
-					.one(data.connection.as_ref().unwrap())
+					.one(db)
 					.await?;
 
 				if derivation.is_some() {
@@ -459,7 +439,7 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 					key_pair: vault_data_from_file.key_pair,
 				};
 
-				let txn = data.connection.as_ref().unwrap().begin().await?;
+				let txn = db.begin().await?;
 
 				let derivation = create_derivation_for_vault_data_to_import(
 					&txn,
@@ -646,7 +626,7 @@ pub async fn retrieve_vault_derivation<T: AddressOrVaultName>(
 
 		let derivation = vault_derivation::Entity::find()
 			.filter(vault_derivation::Column::Name.eq(Some(name.clone())))
-			.one(data.connection.as_ref().unwrap())
+			.one(data.connect_db())
 			.await?;
 
 		let derivation = derivation.ok_or(GcliError::Input(format!(
@@ -661,7 +641,7 @@ pub async fn retrieve_vault_derivation<T: AddressOrVaultName>(
 						vault_derivation::Column::RootAddress.eq(derivation.root_address.clone()),
 					)
 					.filter(vault_derivation::Column::Path.eq(Some(path.clone())))
-					.one(data.connection.as_ref().unwrap())
+					.one(data.connect_db())
 					.await?;
 
 				sub_derivation.ok_or(GcliError::Input(format!(
@@ -671,7 +651,7 @@ pub async fn retrieve_vault_derivation<T: AddressOrVaultName>(
 		}
 	} else if let Some(address) = address_or_vault_name.address() {
 		let derivation = vault_derivation::Entity::find_by_id(address.to_string())
-			.one(data.connection.as_ref().unwrap())
+			.one(data.connect_db())
 			.await?;
 
 		derivation.ok_or(GcliError::Input(format!(
@@ -899,12 +879,12 @@ pub async fn try_fetch_key_pair(
 	address: AccountId,
 ) -> Result<Option<KeyPair>, GcliError> {
 	if let Some(derivation) = vault_derivation::Entity::find_by_id(address.to_string())
-		.one(data.connection.as_ref().unwrap())
+		.one(data.connect_db())
 		.await?
 	{
 		if let Some(vault_account) =
 			vault_account::Entity::find_by_id(derivation.root_address.clone())
-				.one(data.connection.as_ref().unwrap())
+				.one(data.connect_db())
 				.await?
 		{
 			let root_secret_suri = retrieve_suri_from_vault_account(&vault_account)?;
diff --git a/src/conf.rs b/src/conf.rs
index 389dbd5..95f5279 100644
--- a/src/conf.rs
+++ b/src/conf.rs
@@ -84,9 +84,9 @@ pub async fn handle_command(data: Data, command: Subcommand) -> Result<(), GcliE
 		}
 		Subcommand::Show => {
 			println!("{}", data.cfg);
-			if let Some(account_id) = data.cfg.address {
+			if let Some(ref account_id) = data.cfg.address {
 				if let Some(derivation) = vault_derivation::fetch_vault_derivation(
-					data.connection.as_ref().unwrap(),
+					data.connect_db(),
 					account_id.to_string().as_str(),
 				)
 				.await?
diff --git a/src/data.rs b/src/data.rs
index 5061c5d..67d9928 100644
--- a/src/data.rs
+++ b/src/data.rs
@@ -35,7 +35,7 @@ pub struct Data {
 	/// config
 	pub cfg: conf::Config,
 	/// database connection
-	pub connection: Option<DatabaseConnection>,
+	connection: Option<DatabaseConnection>,
 	/// rpc to substrate client
 	pub client: Option<Client>,
 	/// graphql to duniter-indexer
@@ -105,6 +105,12 @@ impl Data {
 	}
 	// --- getters ---
 	// the "unwrap" should not fail if data is well prepared
+	/// Returns the DatabaseConnection reference
+	pub fn connect_db(&self) -> &DatabaseConnection {
+		self.connection
+			.as_ref()
+			.expect("Database connection is not available")
+	}
 	pub fn client(&self) -> &Client {
 		self.client.as_ref().expect("must build client first")
 	}
@@ -236,7 +242,7 @@ impl Data {
 	}
 
 	/// build a database connection
-	pub async fn build_connection(mut self) -> Result<Self, GcliError> {
+	async fn build_connection(mut self) -> Result<Self, GcliError> {
 		let data_dir = self.project_dir.data_dir();
 		let connection = database::build_sqlite_connection(data_dir, SQLITE_DB_FILENAME).await?;
 		self.connection = Some(connection);
-- 
GitLab