算法
函数的增长
渐近记号
用来描述算法渐近运行时间的记号,根据定义域为自然数集的函数来定义。这样的记号对描述最坏情况运行时间函数是方便的,因为该函数通常只定义在整数输入规模上。
对一个给定的函数,用来表示一下函数的集合:
={:存在正常量,使得对所有,有}
若存在正常量和使得对足够大的n,函数能“夹入”与之间,则属于集合,因为是一个集合,所以可以记“”,以指出的成员。作为替代,我们通常记“”以表达相同的概念。
我们称是的一个渐进紧确界(asymptotically tight bound)。
一般来说,对任意多项式,其中为常量且,我们有。
因为任意常量是一个0阶多项式,所以可以把任意常量函数表示成或者。
记号渐近地给出一个函数的上界和下界。当只有一个渐进上界时,使用记号。对于给定的函数,用来表示以下函数的集合:
={:存在常量,使得对所有时,有}
正如记号提供了一个函数的渐近上界,记号提供了渐近下界。对于给定的函数,用来表示以下函数的集合:
={:存在正常量,使得对所有时,有}
定理3.1:对任意两个函数和,我们有,当且仅当:。
由记号提供的渐近上界可能是也可能不是渐近紧确的。界是渐近紧确的,但是界却不是。我们使用o记号来表示一个非渐进紧确的上界。
形式地定义为以下集合:
={:对任意正常量,存在常量,使得对所有,有}
记号与记号的定义类似。主要区别在于是中,界对某个常量成立,但在中,界对所有常量都成立。直观上,在记号中,当趋于无穷时,函数相对于来说变得微不足道了,即:
的关系类似于的关系。我们使用来定义一个非渐近紧确的下界。定义它的一种方式是:
然而我们形式化地定义为以下集合:
={:对任意正常量,存在常量,使得对所有的时,有}
关系蕴含着:
也就是说,如果这个极限存在,当趋于无穷时,来说变得任意大了。
比较各种函数
实数的许多关键性质也适用于渐近比较。下面假定渐近为正。
传递性自反性
对称性
转置对称性
因为这些性质对渐近记号成立,所以可以在两个函数的渐近比较和两个实数的比较之间做一种类比。
三分性 对任意两个实数,下列三种情况恰有一种必须成立:
虽然任意两个时序都可以进行比较,但是不是所有的函数都可渐近比较。也就是说,对两个函数,也许和都不成立。
标准记号与常用函数
单调性
若蕴含,则函数是单调递增的。类似的,若蕴含,则函数是单调递减的。若蕴含,则函数是严格单调递增的。类似的,若蕴含,则函数是严格单调递减的。
向上取整和向下取整
对于任意实数,我们用表示的向下取整,并用表示的向上取整。
对所有实数:
对任意整数n,
对任意实数 和整数 ,
向下取整函数是单调递增的,向上取整函数 也是单调递增的。
模运算
对任意整数和任意正整数,的值就是商的余数。
结果有:
多项式
给定一个非负整数,次多项式为具有以下形式的一个函数:
其中常量是多项式的系数且。一个多项式渐近正的当且仅当。对于一个次渐近正的多项式,有。对任意实常数,函数单调递增,对任意实常量,函数单调递减。若对某个常量,有,则称函数是多项式有界的。
多项式
对所有实数,我们有以下恒等式:
对所有 ,函数单调递增。方便时,我们假定。
可以通过以下事实使 多项式与指数的增长率相关联。多所有使得 的实常量 ,有据此可得
因此, 任意底大于1的指数函数比任意多项式函数增长得快。</mark>
使用来表示自然对数函数的底,对所有实数,我们有
对所有实数,我们有不等式
对所有的,我们有:
对数
我们使用下面的记号:
如果常量 ,那么对于,函数 是严格递增的。
对于所有实数,有
其中,在上面的每个等式中,对数的底不为1。
因此, 任意正的多项式函数都比任意对数函数增长的快。
阶乘
记号,定义为对整数,有:
阶乘函数一个弱上界是 ,因为在阶乘中, 项的每项最多为。 斯特林(Stirling)近似公式给出了一个更紧确的上界和下界,其中是自然对数的底。
对所有,下面的等式也成立:
其中:
多重函数
我们使用记号来表示函数重复i次作用与处置上。形式化地,假设为实数集上的一个函数。对非负整数,我们递归地定义
多重对数函数
我们使用记号来表示多重对数函数,下面给出它的定义。假设定义如上,其中。因为非正数的对数无定义,所以只有在时,才有定义。定义多重对数函数为
多重对数函数是一个增长非常慢的函数。
菲波那切数
使用下面的递归式来定义菲波那切数:
因此,菲波那切数都是前面两个数之和,产生的序列为
菲波那切数与 黄金分割率 以及共轭数有关,它们是下列方程的两个根:
并由下面的公式给出:
线性查找问题
输入:
输出:LINEAR-SEARCH(A, v) for i = 1 to A.length if v == A[i] return i return -1
java实现:
public static int linearSearch(int[] srcArr, int val) { for (int i = 0; i < srcArr.length; i++) { if (srcArr[i] == val) return i; } return -1; }
如果数组已经排好序,就可以将该序列的中点与进行比较根据比较的结果,原序列中有一半就可以不用再进一步的考虑了。二分查找算法重复这个过程,每次都将序列剩余部分的规模减半。
非递归方式:BINARY-SEARCH(A, v) low = 1 high = A.length while low <= high middle = (low + high) / 2 if A[middle] < v low = middle + 1 elseif A[middle] > v high = middle - 1 else return middle return -1
java实现:
/** * 利用非递归形式的二分查找法在数组中寻找特定的值 * @param srcArray 被搜索的整形数组 * @param val 待查找的值 * @return 若该值在数组中,返回该值对应的数组索引。否则返回-1 */ public static int binarySearch(int[]srcArray, int val){ //低位“指针” int low = 0; //高位“指针” int high = srcArray.length - 1; //如果low ≤ high则进行查找, // 因为无论数组元素为偶数个还是奇数个,当要查找的值不在数组中时最后一步查找情况是low和high重合,此时middle=low=high, // 如果srcArray[middle]>val,执行low = middle + 1,此时low>high; //如果srcArray[middle]<val,执行high = middle - 1;,此时low>high; while(low <= high){ int middle = (low + high) >> 1; //当数组中间值小于待查找值,该值“可能”在数组右半侧,并且索引middle处的值已经判断过,所以low=middle+1, //并且如果low=middle,在[srcArray[low], srcArray[high]]会陷入死循环 if(srcArray[middle] < val){ low = middle + 1; }else if(srcArray[middle] > val){ high = middle - 1; //找到待查找值,返回该值对应数组索引 }else{ return middle; } } //当待查找值不在数组中时返回-1 return -1; }
递归方式:
BINARY-SEARCH(A, low, high, val) while low <= high middle = (low + high) / 2 if A[middle] == val return middle elseif A[middle] < val return BINARY-SEARCH(A, middle + 1, high, val) else return BINARY-SEARCH(A, low, middle - 1, val) return -1
java实现:
public static int binarySearch(int[]srcArray, int low, int high, int val){ while (low <= high) { int middle = (low + high) >> 1; if (srcArray[middle] == val) { return middle; } else if (srcArray[middle] < val) return binarySearch(srcArray, middle + 1, high, val); else return binarySearch(srcArray, low, middle - 1, val); } return -1; }
排序
输入:
输出:插入排序
对于少量元素的排序,它是一个有效的算法。
如图所示,一副完整的牌就像一个数组,手中的牌是已经按从小到大排好,这时你从桌子上的牌堆中取出一张牌,你要做的就是将这张牌插入到手中的牌里。手中原来有牌,从桌子上取出一张牌是7,和最大的牌10比,,10往后移一个位置,比,,那么就将7插入刚才10空出的位置,以此类推。
INSERTION-SORT(A) for j = 2 to A.length key = A[j] //Insert A[j] into the sorted sequence A[1..j - 1]. i = j - 1 while i > 0 and A[i] > key A[i + 1] = A[i] i = i - 1 A[i + 1] = key
java实现
/** * 对一个整型数组按从小到大的顺序排序 * @param arr 待排序的数组 */ public static void insertionSort(int[] arr) { for (int i = 0; i < arr.length; i++) { int key = arr[i]; int j = i - 1; //将当前值key与已排好部分arr[0..i-1]中的值挨个比较大小 //如果key大于已排好数组中arr[j],则将arr[j]往后“移一位”,将当前位置腾出来,保存key或前面移过来的值 while (j >= 0 && arr[j] > key) { arr[j + 1] = arr[j]; j--; } //①如果key不小于已排好数组最大值(即已排好部分最后一个值),将key放到arr[j + 1]即arr[i] //相当于将arr[i]拿出来比较一下,发现arr[i]不小于已排好数组中最大值,再讲arr[i]放回去; //②如果key小于已排好数组最大值,那么经过while循环当前的arr[j]是第一个小于key的数,所以将 //key放在arr[j + 1]已腾出来的位置,arr[j + 1..arr[i]]里面的元素已经依次向右移动一个位置。 arr[j + 1] = key; } }
插入排序的:
- 最好运行时间是
- 最坏运行时间是(读作“theta n 平方”)。
事实上,元素A[1, j - 1]就是原来1到j - 1的元素,但是现在已按顺序排列。我们把A[1..j-1]的这些性质形式地表示为一个循环不变式。
循环不变式主要用来帮助我们理解算法的正确性。关于循环不变式,我们必须证明三条性质:
- 初始化:循环的第一次迭代之前,它为真。
- 保持:如果循环的某次迭代之前它为真,那么下次迭代之前塔仍然为真。
- 终止:在循环终止时,不变式为我们提供一个有用的性质,该性质有助于证明算法是正确的。
归并排序
归并排序算法完全遵循分治模式,直观上其操作如下:
- 分解:分解带排序的n个元素的序列成各具n/2个元素的两个子序列。
- 解决:使用归并排序递归地排序两个子序列。
- 合并:合并两个已经排序的子序列以产生已排序的答案。
MERGE-SORT(A, p, r) if p < r q = ⌊(p + r) / 2⌋ MERGE-SORT(A, p, q) MERGE-SORT(A, q, r) MERGE(A, p, q, r) MEARGE(A, p, q, r) n1 = q - p + 1 n2 = r - q let L[1..n1 + 1] and R[1..n2 + 1] be new arrays for i = 1 to n1 L[i] = A[p + i - 1] for j = 1 to n2 R[j] = A[q + j] i = 1 j = 1 for k = p to r if i != n1 and (j == n2 or L[i] ≤ R[j]) A[k] = L[i] i = i + 1 else A[k] = R[j] j = j + 1
java实现
/** * 给定一个数组,起始位置,终止位置,对数组[起始位置..终止位置]按从小到大排序 * @param arr 待排序数组 * @param p 排序起始位置 * @param r 排序终止位置 */ public static void mergeSort(int[] arr, int p, int r){ if (p < r){ //取中间位置,将数组分为左右两部分 int q = (p + r) / 2; //递归对左数组进行排序 mergeSort(arr, p, q); //递归对右数组进行排序 mergeSort(arr, q + 1, r); //调用方法merge将数组中arr[p..r]部分进行排序 merge(arr, p, q, r); } } /** * 给定一个数组arr,给定三个索引参数,p、q、r,满足p≤q≤r,将arr[p..r]分成两个数组 * arr[p..q]和arr[q+1..r],再融合两个数组的过程中对数组进行排序,融合后的数组即排好序的数组 * @param arr 待排序的数组 * @param p 待排序数组首位索引 * @param q 待排序数组中间索引 * @param r 待排序数组摸位索引 */ public static void merge(int[] arr, int p, int q, int r){ int lLen = q - p + 1; int rLen = r - q; //用于接收左半数组 int[] lArr = new int[lLen]; //用于接收右半数组 int[] rArr = new int[rLen]; //将arr[p..q]复制到lArr System.arraycopy(arr, p, lArr, 0, lLen); //将arr[q+1..r]复制到rArr System.arraycopy(arr, q + 1, rArr, 0, rLen); int i = 0, j = 0; for (int k = p; k <= r; k++) { //取lArr中的值首先需要满足lArr数组索引没有越界,这个前提下有两种情况, //①rLen索引已到rLen.length,即rLen中的值都被取出 //②两个数组中都有值,并且lArr[i] <= rArr[j] if (i != lLen && (j == rLen || lArr[i] <= rArr[j])){ arr[k] = lArr[i]; i++; } else { arr[k] = rArr[j]; j++; } } }
选择排序
考虑排序存储在数组A中的n个数:首先找出A中的最小元素并将其与A[1]中的元素进行交换。接着,找出A中次最小元素并将其与A[2]中的元素进行交换。对A中的前n-1个元素按该方式继续。
SELECTION-SORT for i = 1 to A.length - 1 for j = i to A.lenth if A[j] < A[i] //change position temp = A[i] A[i] = A[j] A[j] = temp
java实现:
public static void selectionSort(int[] arr) { for (int i = 0; i < arr.length - 1; i++) { for (int j = i; j < arr.length; j++) { if (arr[j] < arr[i]) { //①a ^ a = 0;②a ^ 0 = a;③a ^ b ^ c = a ^ (b ^ c) = (a ^ b) ^ c; arr[i] = arr[i] ^ arr[j]; // arr[j] = arr[i] ^ arr[j] ^ arr[j] = arr[i] arr[j] = arr[i] ^ arr[j]; // arr[i] = arr[i] ^ arr[j] ^ arr[j] = arr[j] ^ arr[i] ^ arr[i] = arr[j] arr[i] = arr[i] ^ arr[j]; } } } }
分治策略
分治法:将原问题分解为几个规模较小但类似于原问题的子问题,递归地求解这些子问题,然后再合并这些子问题的解来建立原问题的解。
分治模式在每层递归时都有三个步骤:
- 分解原问题为若干子问题,这些子问题是原问题的规模较小的实例。
- 解决这些子问题,递归地求解各个子问题。然而,若子问题的规模足够小,则直接求解。
- 合并这些子问题的解成原问题的解。
递归式:递归式与分治方法是紧密相关的,因为使用递归式可以很自然地刻画分治算法的运行时间。一个递归式(recurrence)就是一个等式或一个不等式。
本章介绍三种求解递归式的方法,即得出算法渐近界的方法。
- 代入法 我们猜测一个界,然后用数学归纳法证明这个界的正确性。
- 递归树法 将递归式转换为一颗树,其节点表示不同层次的递归调用产生的代价。然后采用边界和技术求解递归式。
- 主方法 可求解如下面公式的递归式的界:
其中是一个给定的函数。这种形式的递归式很常见,它刻画了一个这样的分治算法:生成个子问题,每个子问题的规模是原问题的,分解和合并共花费时间。
最大子数组问题
有一整形数组,找出中和为最大的非空连续子数组。我们称这样的连续子数组为最大连续子数组。
暴力求解
遍历所有可能的数组组合,找出其中和最大的。
FIND-MAXIMUM-SUBARRAY(A, low, high) sum = -∞ for i = 1 to A.length tempSum = 0 for j = i to A.length tempSum = tempSum + A[j] if tempSum > sum sum = tempSum max-left = i max-right = j return (max-left, max-right, left-sum + right-sum)
java实现
public static int[] findMaximumSubarray(int[] arr, int low, int high) { //假设arr[0]就是最大连续子数组 int sum = arr[0]; int maxLeft = 0; int maxRight = 0; for (int i = 0; i < arr.length; i++) { int tempSum = 0; for (int j = i; j < arr.length; j++) { tempSum += arr[j]; if (tempSum > sum) { sum = tempSum; maxLeft = i; maxRight = j; } } } return new int[]{maxLeft, maxRight, sum}; }
分治方法
过程FIND-MAX-CORSSING-SUBARRAY接受数组和下标、、作为输入,返回一个下标元祖规定跨越中点的最大子数组的边界,并返回最大子数组中值的和。
FIND-MAXIMUM-SUBARRAY(A, low, high) if high == low return (low, high, A[low]) else mid = ⌊(low + high) / 2⌋ (left-low, left-high, left-sum) = FIND-MAXIMUM-SUBARRAY(A, low, mid) (right-low, right-high, right-sum) = FIND-MAXIMUM-SUBARRAY(A, mid + 1, high) (cross-low, cross-high, cross-sum) = FIND-MAX-CORSSING-SUBARRAY(A, low, mid, high) if left-sum ≥ right-sum and left-sum ≥ cross-sum return (left-low, left-high, left-sum) elseif right-sum ≥ left-sum and right-sum ≥ cross-sum return (right-low, right-high, right-sum) else return (cross-low, cross-high, cross-sum) FIND-MAX-CORSSING-SUBARRAY(A, low, mid, high) left-sum = -∞ sum = 0 for i = mid downto low sum = sum + A[i] if sum > left-sum left-sum = sum max-left = i right-sum = -∞ sum = 0 for j = mid + 1 to high sum = sum + A[j] if sum > right-sum right-sum = sum max-right = j return (max-left, max-right, left-sum + right-sum)
java实现:
/** * 该方法接收一个数组和low、high下标,找出其范围内的最大子数组 * @param arr 被查找的数组 * @param low 低位下标 * @param high 高位下标 * @return 最大子数组的起始位置,终止位置、和 */ public static int[] findMaximumSubarray(int[] arr, int low, int high) { //递归触底反弹,子数组只有一个元素,所以arr[low]本身就是最大子数组 if (low == high) return new int[]{low, high, arr[low]}; else { int mid = (low + high) / 2; int[] leftArr = findMaximumSubarray(arr, low, mid); int[] rightArr = findMaximumSubarray(arr, mid + 1, high); int[] crossingArr = findMaxCrossingSubarray(arr, low, mid, high); if (leftArr[2] >= rightArr[2] && leftArr[2] >= crossingArr[2]) return leftArr; else if (rightArr[2] >= leftArr[2] && rightArr[2] >= crossingArr[2]) return rightArr; else return crossingArr; } } /** * 该方法接收一个数组arr和下标low,mid,high为输入,返回一个下标元组划定跨越中点的最大子数组的边界, * 并返回最大子数组中值的和。 * @param arr 被查询的数组 * @param low 低位下标 * @param mid 中间位置下标,最大子数组跨越该点 * @param high 高位下标 * @return 最大子数组的起始位置,终止位置,和 */ public static int[] findMaxCrossingSubarray(int[] arr, int low, int mid, int high) { int maxLeft = mid; int maxRight = mid + 1; int leftSum = arr[mid]; int sum = 0; for (int i = mid; i >= low; i--) { sum += arr[i]; if (sum > leftSum) { leftSum = sum; maxLeft = i; } } int rightSum = arr[mid + 1]; sum = 0; for (int i = mid + 1; i <= high; i++) { sum += arr[i]; if (sum > rightSum) { rightSum = sum; maxRight = i; } } return new int[]{maxLeft, maxRight, leftSum + rightSum}; }
如果子数组A[low..high]包含n个元素,则调用FIND-MAX-CROSSING-SUBARRAY(A, low, mid, high)花费时间。
初始调用FIND-MAXIMUM-SUBARRAY(A, 1, A.length)会求出A[1..n]的最大子数组。线性非分治方法
从数组的左边界开始,由左至右处理,记录到目前为止已经处理过的最大子数组。若已知A[1..j]的最大子数组,基于如下性质将扩展为A[1..j+1]的最大子数组:A[1..j+1]的最大子数组要么是A[1..j]的最大子数组,要么是某个子数组A[i..j+1](1 ≤ i ≤ j+1)。在已知A[1..j]的最大子数组的情况下,可以在线性时间内找出形如A[i..j+1]的最大子数组。
有问题,但还是先记录下,/** * 该方法接收一个数组和low、high下标,找出其范围内的最大子数组 * @param arr 被查找的数组 * @param low 低位下标 * @param high 高位下标 * @return 最大子数组的起始位置,终止位置、和 */ public static int[] findMaximumSubarray(int[] arr, int low, int high) { int maxLeft = low; int maxRight = low; int sum = arr[low]; int tempSum; for (int i = 0; i < arr.length - 1; i++) { tempSum = 0; for (int j = i + 1; j >= 0 ; j--) { tempSum += arr[j]; if (tempSum > sum) { maxLeft = j; maxRight = i + 1; sum = tempSum; } } } return new int[]{maxLeft, maxRight, sum}; }
矩阵乘法的Strassen算法
若和是的方阵,则对,定义乘积中的元素为:
我们需要计算 个矩阵元素,每个元素是个值得和。根据定义下面过程接收矩阵,返回它们的乘积——矩阵 。假设矩阵都有一个属性,代表该矩阵的行数。
SQUARE-MATRIX-MULTIPLAY(A, B) n = A.rows for i = 1 to n cij = 0 for j = 1 to n for k = 1 to n cij = cij + aik * bkj return C
由于三重for循环的每一重都恰好执行n步,而第三重的加法需要常量时间,因此过程SQUARE-MATRIX-MAULPLAY花费时间。
java模拟
模拟矩阵的类,用二维数组存储值,重写了toString()方法:public class Matrix { private int rows; //矩阵的行数 private int cols; //矩阵的列数 private double[][] matrixArray; //代表矩阵的二维数组 public Matrix(int rows, int cols) { this.rows = rows; this.cols = cols; matrixArray = new double[rows][cols]; } public Matrix(double[][] matrixArray) { rows = matrixArray.length; cols = matrixArray[0].length; this.matrixArray = matrixArray; } public int getRows() { return rows; } public void setRows(int rows) { this.rows = rows; } public int getCols() { return cols; } public void setCols(int cols) { this.cols = cols; } public double[][] getMatrixArray() { return matrixArray; } public void setMatrixArray(double[][] matrixArray) { this.matrixArray = matrixArray; } @Override public String toString() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < matrixArray.length; i++) { for (int j = 0; j < matrixArray[i].length; j++) { sb.append(matrixArray[i][j] + "\t"); if (j == matrixArray[i].length - 1) sb.append("\n"); } } return sb.toString(); } }
算法类,包含一个求两个矩阵积的静态方法:
public class Algorithms { /** * 接收两个矩阵A, B返回两者的乘积 * @param matrixA 乘数A矩阵 * @param matrixB 乘数B矩阵 * @return 两个矩阵的乘积 */ public static Matrix squareMatrixMultiply(Matrix matrixA, Matrix matrixB) { double[][] matrixAArray = matrixA.getMatrixArray(); double[][] matrixBArray = matrixB.getMatrixArray(); int rows = matrixA.getRows(); int cols = matrixB.getCols(); double sum = 0; double[][] matrixCArray = new double[rows][cols]; for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { sum = 0; for (int k = 0; k < cols; k++) { sum = sum + matrixAArray[i][k] * matrixBArray[k][j]; } matrixCArray[i][j] = sum; } } return new Matrix(matrixCArray); } }
测试类:
public class TestAlgorithms { public static void main(String[] args) { double[][] matrixAArray = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; double[][] matrixBArray = {{12.3, 12, 56}, {34, 456, 234.2}, {93, 3434, 1314}}; Matrix A = new Matrix(matrixAArray); Matrix B = new Matrix(matrixBArray); Matrix C = Algorithms.squareMatrixMultiply(A, B); System.out.println(C); } }
输出:
359.3 11226.0 4466.4
777.2 22932.0 9279.0
1195.1 34638.0 14091.6直接的递归分治算法:
SQUARE-MATRIX-MULPLAY-RECURSIVE(A, B) n = A.rows let C be a new n×n matrix if n == 1 c11 = a11 · b11 else C11 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A11, B11) + SQUARE-MATRIX-MULPLAY-RECURSIVE(A12, B21) C12 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A11, B12) + SQUARE-MATRIX-MULPLAY-RECURSIVE(A12, B22) C21 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A21, B11) + SQUARE-MATRIX-MULPLAY-RECURSIVE(A22, B21) C22 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A21, B12) + SQUARE-MATRIX-MULPLAY-RECURSIVE(A22, B22) return C
SQUARE-MATRIX-MULPLAY-RECURSIVE运行时间递归式:
,说明简单的分治算法并不优于直接的SQUARE-MATRIX-MULPLAY算法。
java模拟
模拟矩阵的类继续使用上一个方法中的Matrix
。
算法类:包含复制二维数组、矩阵加法、合并矩阵、矩阵乘法四个静态方法public class Algorithms { /** * 接收两个矩阵A, B返回两者的乘积 * @param matrixA 乘数A矩阵 * @param matrixB 乘数B矩阵 * @return 两个矩阵的乘积 */ public static Matrix squareMatrixMultiplyRecursive(Matrix matrixA, Matrix matrixB) { double[][] A = matrixA.getMatrixArray(); double[][] B = matrixB.getMatrixArray(); int rows = matrixA.getRows(); double[][] C = new double[rows][rows]; if (rows == 1) { C[0][0] = A[0][0] * B[0][0]; return new Matrix(C); } else { int count = rows >> 1; double[][] A11Arr = arrayCopy(A, 0, 0, count); double[][] A12Arr = arrayCopy(A, 0, count, count); double[][] A21Arr = arrayCopy(A, count, 0, count); double[][] A22Arr = arrayCopy(A, count, count, count); double[][] B11Arr = arrayCopy(B, 0, 0, count); double[][] B12Arr = arrayCopy(B,0, count, count); double[][] B21Arr = arrayCopy(B, count, 0, count); double[][] B22Arr = arrayCopy(B, count, count, count); Matrix C11 = add(squareMatrixMultiplyRecursive(new Matrix(A11Arr), new Matrix(B11Arr)), squareMatrixMultiplyRecursive(new Matrix(A12Arr), new Matrix(B21Arr))); Matrix C12 = add(squareMatrixMultiplyRecursive(new Matrix(A11Arr), new Matrix(B12Arr)), squareMatrixMultiplyRecursive(new Matrix(A12Arr), new Matrix(B22Arr))); Matrix C21 = add(squareMatrixMultiplyRecursive(new Matrix(A21Arr), new Matrix(B11Arr)), squareMatrixMultiplyRecursive(new Matrix(A22Arr), new Matrix(B21Arr))); Matrix C22 = add(squareMatrixMultiplyRecursive(new Matrix(A21Arr), new Matrix(B12Arr)), squareMatrixMultiplyRecursive(new Matrix(A22Arr), new Matrix(B22Arr))); return combine(C11, C12, C21, C22); } } /** * 复制二维数组 * @param srcArr 源数组 * @param x 从二维数组中第几个一维数组开始复制 * @param y 从那个一维数组的第几个元素开始复制 * @param count 连续的作用到几个一维数组,每个一维数组复制几个值 * @return 复制好的二维数组 */ public static double[][] arrayCopy(double[][] srcArr, int x, int y, int count) { double[][] destArr = new double[count][count]; for (int i = 0; i < count; i++) { for (int j = 0; j < count; j++) { destArr[i][j] = srcArr[x + i][y + j]; } } return destArr; } /** * 求两个矩阵的和 * @param A 加数矩阵A * @param B 加数矩阵B * @return 两个矩阵的和 */ public static Matrix add(Matrix A, Matrix B) { double[][] aArr = A.getMatrixArray(); double[][] bArr = B.getMatrixArray(); double[][] cArr = new double[aArr.length][bArr[0].length]; for (int i = 0; i < aArr.length; i++) { for (int j = 0; j < aArr[i].length; j++) { cArr[i][j] = aArr[i][j] + bArr[i][j]; } } return new Matrix(cArr); } /** * 将四个矩阵合并为一个矩阵 * @param A11 子矩阵 * @param A12 子矩阵 * @param A21 子矩阵 * @param A22 子矩阵 * @return 合并后的矩阵 */ public static Matrix combine(Matrix A11, Matrix A12, Matrix A21, Matrix A22) { double[][] a11Arr = A11.getMatrixArray(); double[][] a12Arr = A12.getMatrixArray(); double[][] a21Arr = A21.getMatrixArray(); double[][] a22Arr = A22.getMatrixArray(); int rowsA = a11Arr.length; int colsA = a11Arr[0].length; int rowsB = a12Arr.length; int colsB = a12Arr[0].length; int rowsC = a21Arr.length; int colsC = a21Arr[0].length; int rowsD = a22Arr.length; int colsD = a22Arr[0].length; double[][] resultArr = new double[rowsA + rowsC][colsA + colsB]; for (int i = 0; i < rowsA; i++) { for (int j = 0; j < colsA; j++) { resultArr[i][j] = a11Arr[i][j]; } } for (int i = 0; i < rowsB; i++) { for (int j = 0; j < colsB; j++) { resultArr[i][colsA + j] = a12Arr[i][j]; } } for (int i = 0; i < rowsC; i++) { for (int j = 0; j < colsC; j++) { resultArr[rowsA + i][j] = a21Arr[i][j]; } } for (int i = 0; i < rowsD; i++) { for (int j = 0; j < colsD; j++) { resultArr[rowsA + i][colsA + j] = a22Arr[i][j]; } } return new Matrix(resultArr); } }
测试类:
public class TestAlgorithms { public static void main(String[] args) { double[][] matrixAArray = {{1, 2, 3, 3}, {4, 5, 6, 6}, {7, 8, 9, 9}, {1, 1, 1, 1}}; double[][] matrixBArray = {{12.3, 12, 56, 1}, {34, 456, 234.2, 1}, {93, 3434, 1314, 1}, {1, 1, 1, 1}}; Matrix A = new Matrix(matrixAArray); Matrix B = new Matrix(matrixBArray); Matrix C = Algorithms.squareMatrixMultiplyRecursive(A, B); System.out.println(C); } }
输出:
362.3 11229.0 4469.4 9.0
783.2 22938.0 9285.0 21.0
1204.1 34647.0 14100.6 33.0
140.3 3903.0 1605.2 4.0