From 5453bced713a699ab6698f1483c244b9b2fb3820 Mon Sep 17 00:00:00 2001 From: SinTan1729 Date: Sat, 27 May 2023 18:35:06 -0500 Subject: [PATCH] new: Added inverse method --- src/lib.rs | 90 +++++++++++++++++++++++++++++++++++++++++++++++++--- src/tests.rs | 21 ++++++++++++ 2 files changed, 107 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index af8d2f3..6644769 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ //! with any type that implement [`Add`], [`Sub`], [`Mul`], //! [`Zero`], [`Neg`] and [`Copy`]. Additional properties might be //! needed for certain operations. +//! //! I created it mostly to learn using generic types //! and traits. //! @@ -184,7 +185,7 @@ impl Matrix { /// ``` /// use matrix_basic::Matrix; /// let m = Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap(); - /// assert_eq!(m.det(), Ok(-2.0)); + /// assert_eq!(m.det_in_field(), Ok(-2.0)); /// ``` pub fn det_in_field(&self) -> Result where @@ -198,14 +199,14 @@ impl Matrix { let mut multiplier = T::one(); let h = self.height(); let w = self.width(); - for i in 0..h { + for i in 0..(h - 1) { // First check if the row has diagonal element 0, if yes, then swap. if rows[i][i] == T::zero() { let mut zero_column = true; for j in (i + 1)..h { if rows[j][i] != T::zero() { rows.swap(i, j); - multiplier = T::zero() - multiplier; + multiplier = -multiplier; zero_column = false; break; } @@ -248,7 +249,7 @@ impl Matrix { let mut offset = 0; let h = self.height(); let w = self.width(); - for i in 0..h { + for i in 0..(h - 1) { // Check if all the rows below are 0 if i + offset >= self.width() { break; @@ -399,6 +400,87 @@ impl Matrix { } } + /// Returns the inverse of a square matrix. Throws an error if the matrix isn't square. + /// /// # Example + /// ``` + /// use matrix_basic::Matrix; + /// let m = Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap(); + /// let n = Matrix::from(vec![vec![-2.0, 1.0], vec![1.5, -0.5]]).unwrap(); + /// assert_eq!(m.inverse(), Ok(n)); + /// ``` + pub fn inverse(&self) -> Result + where + T: Div, + T: One, + T: PartialEq, + { + if self.is_square() { + // We'll use the basic technique of using an augmented matrix (in essence) + // Cloning is necessary as we'll be doing row operations on it. + let mut rows = self.entries.clone(); + let h = self.height(); + let w = self.width(); + let mut out = Self::identity(h).entries; + + // First we get row echelon form + for i in 0..(h - 1) { + // First check if the row has diagonal element 0, if yes, then swap. + if rows[i][i] == T::zero() { + let mut zero_column = true; + for j in (i + 1)..h { + if rows[j][i] != T::zero() { + rows.swap(i, j); + out.swap(i, j); + zero_column = false; + break; + } + } + if zero_column { + return Err("Provided matrix is singular."); + } + } + for j in (i + 1)..h { + let ratio = rows[j][i] / rows[i][i]; + for k in i..w { + rows[j][k] = rows[j][k] - rows[i][k] * ratio; + } + // We cannot skip entries here as they might not be 0 + for k in 0..w { + out[j][k] = out[j][k] - out[i][k] * ratio; + } + } + } + + // Then we reduce the rows + for i in 0..h { + if rows[i][i] == T::zero() { + return Err("Provided matrix is singular."); + } + let divisor = rows[i][i]; + for entry in rows[i].iter_mut().skip(i) { + *entry = *entry / divisor; + } + for entry in out[i].iter_mut() { + *entry = *entry / divisor; + } + } + + // Finally, we do upside down row reduction + for i in (1..h).rev() { + for j in (0..i).rev() { + let ratio = rows[j][i]; + for k in 0..w { + out[j][k] = out[j][k] - out[i][k] * ratio; + } + } + } + + Ok(Matrix { entries: out }) + } else { + Err("Provided matrix isn't square.") + } + } + // TODO: Canonical forms, eigenvalues, eigenvectors etc. } diff --git a/src/tests.rs b/src/tests.rs index 4163282..388f9c4 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -78,3 +78,24 @@ fn conversion_test() { let c = Matrix::::matrix_from(a); assert_eq!(c, b); } + +#[test] +fn inverse_test() { + let a = Matrix::from(vec![vec![1.0, 2.0], vec![1.0, 2.0]]).unwrap(); + let b = Matrix::from(vec![ + vec![1.0, 2.0, 3.0], + vec![0.0, 1.0, 4.0], + vec![5.0, 6.0, 0.0], + ]) + .unwrap(); + let c = Matrix::from(vec![ + vec![-24.0, 18.0, 5.0], + vec![20.0, -15.0, -4.0], + vec![-5.0, 4.0, 1.0], + ]) + .unwrap(); + + println!("{:?}", a.inverse()); + assert!(a.inverse().is_err()); + assert_eq!(b.inverse(), Ok(c)); +}