0.序
题目大概意思是给出 n 组对应关系, 将它们打乱, 求最后至少有 k 组对应关系正确的打乱方式
思路是从 k 到 n 枚举正确的对应关系个数, 求组合数 Cn(k) * 剩下 n - k 个对应关系完全错误的排法
1.排错公式
当 n 个编号元素放在 n 个编号位置,元素编号与位置编号全都不对应的方法数用 dp[n] 表示
显然 dp[1] = 0, dp[2] = 1;
将 n 个元素错排则
- (1) 将第 1 个元素放到错误的 (n - 1) 个位置
- (2) 假设(1)中将第 1 个元素放在了 k 位置, 那么考虑两种情况:
- 将第 k 个元素放到第 1 个位置, 等价于将第 1 个元素与第 k 个元素交换, 递推考虑错排 n- 2 个元素的方法, 为dp[n - 2]种方法.
- 将第 k 个元素放到 n - 2 个其他的除了第 1 个位置以外的错误位置, dp[n - 1]种方法.
=> dp[n] = (n - 1)(dp[n - 1] + dp[n - 2])
2. 乘法逆元
- 扩展欧几里德算法
用于求出一组 x 和 y 满足方程 ax+by = gcd(a, b)
inline long long ExGCD(long long A, long long B, long long& x, long long& y){
if(A == 0 && B == 0) return -1;
if(B == 0){x = 1, y = 0; return A;}
long long d = ExGCD(B, A % B, y, x);
y -= A / B * x;
return d;
}
因为题目给的要取余的数 MOD 一般是 1e9 + 7 之类的大质数所以 gcd(a, MOD) == 1
即可求出乘法逆元
long long ModReverse(long long a, long long f){
long long x, y, d = ExGCD(a, f, x , y);
if(d == 1){
if(x % f <= 0) return x % f + f;
else return x % f;
}
return -1;
}
费马小定理
MOD 为素数时, a ^ (MOD - 1) == 1(mod MOD), 所以 a ^ (MOD - 2)即为 a 模 MOD 意义下的乘法逆元, 快速幂取模即可.时间复杂度为 O(n) 的递推求 [1, n] 的逆元表(n < MOD)
推导
假设前 i - 1 个数的乘法逆元都已知
由 MOD mod i + (MOD div i) * i = MOD
=> MOD mod i + (MOD div i) * i = 0 (mod MOD)
=> (MOD div i) * i = - (MOD mod i) (mod MOD)
=> i ^ -1 = - (MOD div i) / (MOD mod i)
=> i ^ -1 = - (MOD div i) * (MOD mod i) ^ -1
因为 MOD mod i 小于 i 所以 (MOD mod i) ^ -1 是已知的, 即可得到以下递推关系
inv[1] = 1;
for(int i = 2; i <= 10000; i++) inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
3. 组合数
直接公式:3. AC代码
逆元表写的, 比较喜欢打表预处理
#include <iostream>
#include <cstdio>
using namespace std;
const long long MOD = 1e9 + 7;
long long dp[10007], inv[10007];
int main()
{
int T, n, k;
long long C, ans;
scanf("%d", &T);
dp[1] = 0, dp[2] = 1, inv[1] = 1;
for(int i = 3; i <= 10000; i++) dp[i] = ((dp[i - 1] + dp[i - 2]) % MOD) * (i - 1) % MOD;
for(int i = 2; i <= 10000; i++) inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
for(int cas = 1; cas <= T; cas++){
scanf("%d%d", &n, &k);
ans = 1, C = n;
for(int i = 2; i <= k; i++) C = ((C * (n - i + 1) % MOD) * inv[i]) % MOD;
for(int i = k; i < n; i++){
ans = (ans + (C * dp[n - i]) % MOD ) % MOD;
C = ((C * (n - i) % MOD) * inv[i + 1]) % MOD;
}
printf("%I64d\n", ans);
}
return 0;
}