【分治】逆序对计算-Count Inversion

逆序对定义

如果存在正整数i,j使得1\leq i<j\leq n,而且A[i]>A[j],则<A[i],A[j]>这个有序对称为A的一个逆序对,也称作逆序数。

例题1 - LeetCode剑指 Offer 51. 数组中的逆序对

题目

在数组中的两个数字,如果前面一个数字大于后面的数字,则这两个数字组成一个逆序对。输入一个数组,求出这个数组中的逆序对的总数。

输入: [7,5,6,4]
输出: 5
限制:0 <= 数组长度 <= 50000
思路

逆序对:<7,5>,<7,6>,<7,4>,<5,4>,<6,4>

1. 暴力枚举

数组中每个元素A[i]与其后面每个元素A[j](j>i)进行比较,若A[i]>A[j]则逆序对数量加1。
时间复杂度:O(n^2)

2. 分治

若数组元素个数为0或1,则该数组逆序对数量为0;若数组元素为有序,则该数组逆序对数量为0。可以发现,逆序对数量其实就是将无序数组排为有序后,数组元素交换的次数。

使用分治算法,递归将数组进行二分(low ~ middle 和 middle+1 ~ high),直至为仅剩1个元素。此时逆序对数量为0。再将数组(low ~ middle 和 middle+1 ~ high)两两依次合并,合并时若左半部分数组中的元素A[i],(low\leq i\leq mid),则逆序对数量增加mid-i+1
以题目为例:[7, 5, 6 ,4 ]

逆序对.jpg

时间复杂度:O(nlog n)

具体代码:

#include <bits/stdc++.h>
using namespace std;
vector<int> tmp;
long long int sum = 0;
long long int merge(vector<int> &nums,int low,int mid,int high){
    for (int i=low;i<=high;i++)
        tmp[i] = nums[i];
    int i = low;
    int j = mid+1;
    int k = low;
    while (i<=mid && j<=high){
        if (tmp[i]<=tmp[j])
            nums[k++] = tmp[i++];
        else {
            nums[k++] = tmp[j++];   
            sum += (mid-i+1);   
        }
    }
    while(i<=mid)
        nums[k++] = tmp[i++];
    while(j<=high)
        nums[k++] = tmp[j++];   
    return sum;
}

long long int mergeSort(vector<int> &nums,int low,int high){
    if (low == high)
        return 0;
    int mid=low+(high-low)/2;
    mergeSort(nums,low,mid);
    mergeSort(nums,mid+1,high);
    return merge(nums,low,mid,high);
}

int main(){
    int n,num;
    vector<int> nums;
    scanf("%d",&n);
    for (int i=0;i<n;i++){
        scanf("%d",&num);
        nums.push_back(num);
    }
    if (nums.size() < 2){
        cout << sum << endl;
        return 0;
    }
    tmp.resize(nums.size(),0);
    printf("%lld",mergeSort(nums,0,nums.size()-1));
    return 0;
}



例题2 - Count Inversion

Problem Description

Recall the problem of finding the number of inversions. As in the course, we are given a sequence of n numbers a_1,a_2,...a_n and we define an inversion to be a pair i<j such that a_i>a_j
We motivated the problem of counting inversions as a good measure of how different two orderings are. However, one might feel that this measure is too sensitive. Let's call a pair a significant inversion if i<j and a_i>3a_j. Give an O(nlog n)algorithm to count the number of significant inversions between two orderings.
The array contains N elements (1 \leq N \leq 100,000). All elements are in the range from 1 to 1,000,000,000.

Input

The first line contains one integer N , indicating the size of the array. The second line contains N elements in the array.
· 50% test cases guarantee that N<100

Output

Output a single integer which is the number of pairs of significant inversions.

Sample Inout

6
13 8 5 3 2 1

Sample Output

6

题意与例题1相同,只不过增加一个限定条件:a_i>3a_j,但此时使用分治算法后,将数组(low~middle 和 middle+1~high)两两依次合并时,合并时若左半部分数组中的元素A[i],(low \leq i \leq mid)大于右半部分数组元素A[j],(mid+1 \leq j \leq high),且A[pos]>3*A[j],(i \leq pos \leq mid) 则逆序对数量增加mid-pos+1。即,不能仅仅通过比较A[i],A[j]就增加逆序对数量,如进行[5,8,13][2,3]合并时,虽然5<3*2但是5后面的元素还存在大于3*2的元素,所以此时需要遍历左半部分数组,直至出现第一个大于三倍A[j]的元素。因此需在原代码基础上进行修改。
#include <bits/stdc++.h>
using namespace std;
vector<int> tmp;
long long int sum = 0;
long long int merge(vector<int> &nums,int low,int mid,int high){
    for (int i=low;i<=high;i++)
        tmp[i] = nums[i];
    int i = low;
    int j = mid+1;
    int k = low;
    while (i<=mid&&j<=high){
        if (tmp[i] <= tmp[j])
            nums[k++] = tmp[i++];
        else {
            int pos = i;
            while (pos <= mid){
                if (tmp[pos] > (long long)3*tmp[j]){//此处为了避免乘以3后超出范围采用long long强制转换。(OJ没满分就因为这。)
                    sum += (mid-pos+1);
                    break;
                }
                pos++;
            }   
            nums[k++] = tmp[j++];   
        }
            
    }
    while(i<=mid)
        nums[k++] = tmp[i++];
    while(j<=high)
        nums[k++] = tmp[j++];   
    return con;
}

long long int mergeSort(vector<int> &nums,int low,int high){
    if (low == high)
        return 0;
    int mid = low+(high-low)/2;
    mergeSort(nums,low,mid);
    mergeSort(nums,mid+1,high);
    return merge(nums,low,mid,high);
}

int main(){
    int n,num;
    vector<int> nums;
    scanf("%d",&n);
    for (int i=0;i<n;i++){
        scanf("%d",&num);
        nums.push_back(num);
    }
    if (nums.size() < 2){
        cout << sum << endl;
        return 0;
    }
    tmp.resize(nums.size(),0);
    printf("%lld",mergeSort(nums,0,nums.size()-1));
    return 0;
}
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。