Skip to content

矩阵乘法

P3328

神题,线段树+矩阵乘法

1. 矩阵乘法预处理

首先,设\(f[i]=F[a[i]]\),而这个可以用\(3\times 3\)矩阵预处理出来。

具体地,设三元向量\([f_{k+2},f_{k+1},1]\),则有递推式

\[[f_{k+2},f_{k+1},1]=[f_{k+1},f_k,1]\times\left[ \begin{array}{lll}1 & 1& 0\\a & 0 & 0\\b& 0 & 1 \end{array} \right]\]

\[[f_{k+2},f_{k+1},1]=[2,1,1]\times\left[ \begin{array}{lll}1 & 1& 0\\a & 0 & 0\\b& 0 & 1 \end{array} \right]^k\]

2. 线段树维护

有了\(f[i]\)数组,我们就可以用线段树维护了。

每个线段树节点建两个二维\(3\times 3\)数组,\(sum\)\(data\),分别维护:

\[sum: \left[\begin{array}{l}f_{a_{i-1}-1} & f_{a_{i-1}}& f_{a_{i-1}+1}\\ f_{a_{i+1}-1}& f_{a_{i+1}}& f_{a_{i+1}+1}\end{array}\right] \\ data:\left[\begin{array}{l}f_{a_{i-1}-1}\times f_{a_{i+1}-1}& f_{a_{i-1}-1}\times f_{a_{i+1}}& f_{a_{i-1}-1}\times f_{a_{i+1}+1}\\ f_{a_{i-1}}\times f_{a_{i+1}-1}& f_{a_{i}-1}\times f_{a_{i+1}}& f_{a_{i-1}}\times f_{a_{i+1}+1}\\ f_{a_{i-1}+1}\times f_{a_{i+1}-1}& f_{a_{i-1}+1}\times f_{a_{i+1}}& f_{a_{i-1}+1}\times f_{a_{i+1}+1}\\ \end{array}\right] \]

这样,每次遇到\(+1,-1\)操作时,可以直接用现有的值求出未知量。

同时,因为同一个区间对\(a_{i-1},a_{i+1}\)的影响不同,每个\([l,r]\)更改会对\([l+1,r+1]\)\(a_{i-1}\)进行更改,而对\([l-1,r-1]\)\(a_{i+1}\)进行更改。

所以我们每次更改两次,记录一个\(t\),表示是改\(a_{i-1}\)还是\(a_{i+1}\).

对于加操作,有\(f_{k+2}=f_{k+1}+a\times f_k+b\):

void add(int i,int t) {
    int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
    for(int k=0; k<=1; ++k) _s(i,t,k)=_s(i,t,k+1);
    _s(i,t,2)=(_s(i,t,1)+a*_s(i,t,0)%P+b*w%P)%P;
    if(t==0) { //a[i-1]
        for(int k=0; k<=2; ++k)
            for(int j=0; j<=1; ++j) _d(i,j,k)=_d(i,j+1,k);
        for(int k=0; k<=2; ++k) _d(i,2,k)=(_d(i,1,k)+_d(i,0,k)*a%P+b*_s(i,1,k)%P)%P;
    } else {
        for(int j=0; j<=2; ++j)
            for(int k=0; k<=1; ++k) _d(i,j,k)=_d(i,j,k+1);
        for(int j=0; j<=2; ++j) _d(i,j,2)=(_d(i,j,1)+a*_d(i,j,0)%P+b*_s(i,0,j)%P)%P;
    }//a[i+1]
    return;
}

注意线段树维护区间,所以\(b\)要乘以区间长\(r-l+1\)

对于减操作需要解方程求出\(f_k\)

\(f_k=\left\{\begin{array}{r}\frac{f_{k+2}-f_{k+1}-b}{a} (a\not =0) \\ f_{k+1}-b (a=0) \end{array} \right.\)

则对于\(a\)特判,有:

void del(int i,int t) {
    int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
    if(a==0){
        for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
        _s(i,t,0)=(_s(i,t,1)-b*w%P+P)%P;
        if(t==0){//a[i-1]
            for(int k=0;k<=2;++k)
                for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
            for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,1,k)-b*_s(i,1,k)%P+P)%P;
        }else{//a[i+1]
            for(int j=0;j<=2;++j)
                for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
            for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,1)-b*_s(i,0,j)%P+P)%P;
        }
    }else{
        for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
        _s(i,t,0)=(_s(i,t,2)-_s(i,t,1)+P-b*w%P+P)%P*inva%P;
        if(t==0){//a[i-1]
            for(int k=0;k<=2;++k)
                for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
            for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,2,k)-_d(i,1,k)+P-b*_s(i,1,k)%P+P)%P*inva%P;
        }else{//a[i+1]
            for(int j=0;j<=2;++j)
                for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
            for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,2)-_d(i,j,1)+P-b*_s(i,0,j)%P+P)%P*inva%P;
        }
    }
}

那么我们需要两个\(lazytag\),分别表示\(a_{i-1}\)\(a_{i+1}\)的变化量。

\(pushup\)直接暴力合并,\(pushdown\)也直接计算即可。

\(change\)需要区分\(t\)\(query\)返回\(data[2][0]\)即可。

至此,我们完整的过了一遍主要流程。

~~卡常一直是我的痛,所以一下代码要吸氧才能过~~

#include<iostream>
#include<cstdio>
#include<cstring>
#include<ctime>
#define ls (i<<1)
#define rs (i<<1|1)
#define mid (l+r>>1)
#define _s(i,j,k) tre[i].sum[j][k]
#define _d(i,j,k) tre[i].data[j][k]
#define ll long long 
using namespace std;
const int N=3e5+10,P=1e9+7;
struct ma {
    ll f[3][3];
    int n,m;
    ma() {
        memset(f,0,sizeof f);n=m=0;
    }
} s,t;
int read1(){
    int x=0;char ch=getchar();
    while(ch<'0' || ch>'9') ch=getchar();
    while(ch>='0' && ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
    return x;
}
void write1(ll x){
    if(x>9) write1(x/10);
    putchar(x%10+'0');return;
}
ma operator *(ma x,ma y) {
    ma z=ma();
    int n=x.n,m=x.m,p=y.m;
    z.n=n,z.m=p;
    for(int i=0; i<n; ++i)
        for(int j=0; j<p; ++j)
            for(int k=0; k<m; ++k) z.f[i][j]=(z.f[i][j]+x.f[i][k]*y.f[k][j]%P)%P;
    return z;
}
struct tree {
    int l,r,tag[2];
    ll sum[3][3],data[3][3];
} tre[N<<2];
int n,q,_l,_r,x,y;
ll inva,a,b;
ll f[N][3];
int A[N];
char ch[10],_c;
ll kp(ll x,int p) {
    if(p==0) return 1;
    if(p==1) return x%P;
    if(p&1) return x*kp(x*x%P,p>>1)%P;
    else return kp(x*x%P,p>>1)%P;
}
ma Kp(ma x,int p) {
    if(p==1) return x;
    if(p&1) return x*Kp(x*x,p>>1);
    else return Kp(x*x,p>>1);
}
void initf() {
    inva=kp(a,P-2)%P;
    s.n=s.m=3;
    s.f[0][0]=s.f[0][1]=s.f[2][2]=1;s.f[1][0]=a;s.f[2][0]=b;
    t.n=1,t.m=3;
    t.f[0][0]=2;t.f[0][1]=t.f[0][2]=1;
    for(int i=1; i<=n; ++i) {
        if(A[i]==1) {
            f[i][0]=-P,f[i][1]=1,f[i][2]=2;
        } else if(A[i]==2) {
            f[i][0]=1,f[i][1]=2;
            f[i][2]=(f[i][1]+a*f[i][0]%P+b)%P;
        } else {
            ma tmp=t*Kp(s,A[i]-2);
            f[i][0]=tmp.f[0][1];
            f[i][1]=tmp.f[0][0];
            f[i][2]=(f[i][1]+a*f[i][0]%P+b)%P;
        }
    }
}
void pushup(int i) {
    for(int j=0; j<=1; ++j)
        for(int k=0; k<=2; ++k) _s(i,j,k)=(_s(ls,j,k)+_s(rs,j,k))%P;
    for(int j=0; j<=2; ++j)
        for(int k=0; k<=2; ++k) _d(i,j,k)=(_d(ls,j,k)+_d(rs,j,k))%P;
}
void build(int i,int l,int r) {
    tre[i].l=l,tre[i].r=r;
    if(l==r) {
        for(int k=0; k<=2; ++k) _s(i,0,k)=f[l-1][k];
        for(int k=0; k<=2; ++k) _s(i,1,k)=f[l+1][k];
        for(int j=0; j<=2; ++j)
            for(int k=0; k<=2; ++k) _d(i,j,k)=_s(i,0,j)*_s(i,1,k)%P;
        return;
    }
    int mid=l+r>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(i);return;
}
void add(int i,int t) {
    int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
    for(int k=0; k<=1; ++k) _s(i,t,k)=_s(i,t,k+1);
    _s(i,t,2)=(_s(i,t,1)+a*_s(i,t,0)%P+b*w%P)%P;
    if(t==0) { //a[i-1]
        for(int k=0; k<=2; ++k)
            for(int j=0; j<=1; ++j) _d(i,j,k)=_d(i,j+1,k);
        for(int k=0; k<=2; ++k) _d(i,2,k)=(_d(i,1,k)+_d(i,0,k)*a%P+b*_s(i,1,k)%P)%P;
    } else {
        for(int j=0; j<=2; ++j)
            for(int k=0; k<=1; ++k) _d(i,j,k)=_d(i,j,k+1);
        for(int j=0; j<=2; ++j) _d(i,j,2)=(_d(i,j,1)+a*_d(i,j,0)%P+b*_s(i,0,j)%P)%P;
    }//a[i+1]
    return;
}
void del(int i,int t) {
    int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
    if(a==0){
        for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
        _s(i,t,0)=(_s(i,t,1)-b*w%P+P)%P;
        if(t==0){//a[i-1]
            for(int k=0;k<=2;++k)
                for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
            for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,1,k)-b*_s(i,1,k)%P+P)%P;
        }else{//a[i+1]
            for(int j=0;j<=2;++j)
                for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
            for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,1)-b*_s(i,0,j)%P+P)%P;
        }
    }else{
        for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
        _s(i,t,0)=(_s(i,t,2)-_s(i,t,1)+P-b*w%P+P)%P*inva%P;
        if(t==0){//a[i-1]
            for(int k=0;k<=2;++k)
                for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
            for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,2,k)-_d(i,1,k)+P-b*_s(i,1,k)%P+P)%P*inva%P;
        }else{//a[i+1]
            for(int j=0;j<=2;++j)
                for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
            for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,2)-_d(i,j,1)+P-b*_s(i,0,j)%P+P)%P*inva%P;
        }
    }
}
void calc(int i,int t,int c) {
    if(!c) return;
    if(c>0) for(int k=1; k<=c; ++k) add(i,t);
    else for(int k=1; k<=-c; ++k) del(i,t);
}
void pushdown(int i) {
    int l=tre[i].l,r=tre[i].r,c=tre[i].tag[0],d=tre[i].tag[1];
    if(!c && !d) return;
    tre[ls].tag[0]+=c;tre[ls].tag[1]+=d;
    tre[rs].tag[0]+=c;tre[rs].tag[1]+=d;
    calc(ls,0,c);calc(ls,1,d);
    calc(rs,0,c);calc(rs,1,d);
    tre[i].tag[0]=tre[i].tag[1]=0;
}
void change(int i,int el,int er,int t,int c) {
    int l=tre[i].l,r=tre[i].r;
    if(el<=l && r<=er) {
        tre[i].tag[t]+=c;
        calc(i,t,c);
        return;
    }
    pushdown(i);
    if(el<=mid) change(ls,el,er,t,c);
    if(er>mid) change(rs,el,er,t,c);
    pushup(i);
    return;
}
ll query(int i,int el,int er) {
    int l=tre[i].l,r=tre[i].r;
    if(el<=l && r<=er) return _d(i,2,0);
    ll ans=0;
    pushdown(i);
    if(el<=mid) ans=(ans+query(ls,el,er))%P;
    if(er>mid) ans=(ans+query(rs,el,er))%P;
    return ans;
}
int main() {
    n=read1(),q=read1(),a=(long long)read1(),b=(long long)read1();
    for(int i=1; i<=n; ++i) A[i]=read1();
    initf();
    build(1,1,n);
    for(int i=1; i<=q; ++i) {
        scanf("%s",ch);_l=read1(),_r=read1();
        if(ch[0]=='q') {
            if(_l+1<=_r-1) write1(query(1,_l+1,_r-1)),putchar('\n');
            else putchar('0'),putchar('\n');
        } else if(ch[0]=='p') {
            x=_l+1,y=_r+1<n?_r+1:n;change(1,x,y,0,1);
            x=_l-1>1?_l-1:1,y=_r-1;change(1,x,y,1,1);
        } else { //m
            x=_l+1,y=_r+1<n?_r+1:n;change(1,x,y,0,-1);
            x=_l-1>1?_l-1:1,y=_r-1;change(1,x,y,1,-1);
        }
    }
    return 0;
}

~~学流程用了一天半,写代码不到两小时写了两百行,所以思路清晰是非常重要的。~~