Skip to content

中国剩余定理

用来求解这样的方程组:

\(\left\{ \begin{aligned} x\equiv a_1 \pmod {m_1}\\ x\equiv a_2 \pmod {m_2}\\ ...\\ x\equiv a_k \pmod {m_k} \end{aligned} \right.\)

所有\(m\)互质。

做法

定义\(M=\prod_{i=1}^k m_i,M_i=\frac M {m_i}\),则

\(\gcd(M_i,m_i)=1\),因为所有的\(m\)都互质。

此时,定义\(M_it_i\equiv 1\pmod {m_i},t_i=M_i^{-1}\pmod {m_i}\)

\(t_i\)可以用扩欧求出。

所以:\(a_iM_it_i\equiv a_i\pmod{m_i}\)

又因为对于\(i\not = j,M_i=km_j\),所以\(a_iM_it_i\equiv 0\pmod {m_j}\)

因此可行解为\(\sum_{i=1}^ka_iM_it_i\)

P1495

模板题。

#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long 
using namespace std;
const int N=1e4+10;
int n,M,_M,g,x,y,t,ans;
int a[N],m[N];
void exgcd(int a,int b,int &d,int &x,int &y);   
signed main(){
    M=1;scanf("%lld",&n);
    for(int i=1;i<=n;++i)scanf("%lld%lld",&m[i],&a[i]),M*=m[i];
    for(int i=1;i<=n;++i){
        _M=M/m[i];
        exgcd(_M,m[i],g,x,y);
        t=x;
        ans=(ans+t*_M*a[i])%M;
    }
    printf("%lld",(ans+M)%M);
    return 0;
}
void exgcd(int a,int b,int &d,int &x,int &y){
    if(!b)d=a,x=1,y=0;
    else exgcd(b,a%b,d,y,x),y-=a/b*x;
}

P8178

非常妙的一道题,虽然做法中没有\(CRT\),但是有一条重要性质。

首先题意可以转化为\(A_if_0+B_i\equiv 0\pmod {p_i}\).

定义\(A_i=\prod_{i=1}^ka_i,B_i=\sum_{i=1}^k(\prod_{j=i+1}^ka_j)b_i\)

则有\(A_i=a_iA_{i-1},B_i=a_iB_{i-1}+b_i\)

则令\(B=B_i\mod p_i,A=A_i\mod p_i\)

这样,其中一个方程的解可以表示为\(f_0=(p_i-B)* A^{-1}\mod p_i\)(1)\(A^{-1}\)为逆元。

但是,这样的话,每个方程都可能有一个不同的解,怎么判断是否可以合并为一个解呢?

答案是中国剩余定理,对于\(x_i\not = x_j\)\(p_i=p_j\)时无解,因此只要不存在这种情况都有解。

具体地,由 (1) 可知,任何方程可以表示为\(f_0\equiv (p_i-B)* A^{-1} \pmod {p_i}\to x\equiv a_i\pmod{m_i}\),因此可以用\(CRT\)求解。

注意特判\(A=0\).

#include<iostream>
#include<cstdio>
#include<cstring>
#include<map>
#define int long long 
using namespace std;
const int N=2e3+10;
int T,k,_A,_B,A,B;
int a[N],b[N],p[N],vis[N];
map<int,int> M;
int kp(int x,int p,int P){
    if(p==0)return 1;
    if(p==1)return x;
    if(p&1)return x*kp(x*x%P,p>>1,P)%P;
    else return kp(x*x%P,p>>1,P)%P;
}
signed main(){
    scanf("%lld",&T);
    while(T--){
        bool flag=true;memset(vis,0,sizeof vis);
        M.clear();
        scanf("%lld",&k);
        for(int i=1;i<=k;++i)scanf("%lld",&a[i]);
        for(int i=1;i<=k;++i)scanf("%lld",&b[i]);
        for(int i=1;i<=k;++i)scanf("%lld",&p[i]);
        for(int i=1;i<=k;++i){
            A=1,B=0;
            for(int j=1;j<=i;++j) A=(A*a[j])%p[i],B=(B*a[j]%p[i]+b[j])%p[i];
            if(A==0 && B!=0){flag=false;break;}
            if(A==0 && B==0)continue;
            int tmp=(p[i]-B)*kp(A,p[i]-2,p[i])%p[i];
            if(M.find(p[i])==M.end())M[p[i]]=tmp;
            else if(M[p[i]]!=tmp){flag=false;break;}    
        }
        if(flag)printf("Yes\n");
        else printf("No\n");
    }
    return 0;
}