矩阵乘法
P3328
神题,线段树+矩阵乘法
1. 矩阵乘法预处理
首先,设\(f[i]=F[a[i]]\),而这个可以用\(3\times 3\)矩阵预处理出来。
具体地,设三元向量\([f_{k+2},f_{k+1},1]\),则有递推式
则
2. 线段树维护
有了\(f[i]\)数组,我们就可以用线段树维护了。
每个线段树节点建两个二维\(3\times 3\)数组,\(sum\)和\(data\),分别维护:
这样,每次遇到\(+1,-1\)操作时,可以直接用现有的值求出未知量。
同时,因为同一个区间对\(a_{i-1},a_{i+1}\)的影响不同,每个\([l,r]\)更改会对\([l+1,r+1]\)的\(a_{i-1}\)进行更改,而对\([l-1,r-1]\)的\(a_{i+1}\)进行更改。
所以我们每次更改两次,记录一个\(t\),表示是改\(a_{i-1}\)还是\(a_{i+1}\).
对于加操作,有\(f_{k+2}=f_{k+1}+a\times f_k+b\):
void add(int i,int t) {
int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
for(int k=0; k<=1; ++k) _s(i,t,k)=_s(i,t,k+1);
_s(i,t,2)=(_s(i,t,1)+a*_s(i,t,0)%P+b*w%P)%P;
if(t==0) { //a[i-1]
for(int k=0; k<=2; ++k)
for(int j=0; j<=1; ++j) _d(i,j,k)=_d(i,j+1,k);
for(int k=0; k<=2; ++k) _d(i,2,k)=(_d(i,1,k)+_d(i,0,k)*a%P+b*_s(i,1,k)%P)%P;
} else {
for(int j=0; j<=2; ++j)
for(int k=0; k<=1; ++k) _d(i,j,k)=_d(i,j,k+1);
for(int j=0; j<=2; ++j) _d(i,j,2)=(_d(i,j,1)+a*_d(i,j,0)%P+b*_s(i,0,j)%P)%P;
}//a[i+1]
return;
}
注意线段树维护区间,所以\(b\)要乘以区间长\(r-l+1\)
对于减操作需要解方程求出\(f_k\)。
\(f_k=\left\{\begin{array}{r}\frac{f_{k+2}-f_{k+1}-b}{a} (a\not =0) \\ f_{k+1}-b (a=0) \end{array} \right.\)
则对于\(a\)特判,有:
void del(int i,int t) {
int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
if(a==0){
for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
_s(i,t,0)=(_s(i,t,1)-b*w%P+P)%P;
if(t==0){//a[i-1]
for(int k=0;k<=2;++k)
for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,1,k)-b*_s(i,1,k)%P+P)%P;
}else{//a[i+1]
for(int j=0;j<=2;++j)
for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,1)-b*_s(i,0,j)%P+P)%P;
}
}else{
for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
_s(i,t,0)=(_s(i,t,2)-_s(i,t,1)+P-b*w%P+P)%P*inva%P;
if(t==0){//a[i-1]
for(int k=0;k<=2;++k)
for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,2,k)-_d(i,1,k)+P-b*_s(i,1,k)%P+P)%P*inva%P;
}else{//a[i+1]
for(int j=0;j<=2;++j)
for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,2)-_d(i,j,1)+P-b*_s(i,0,j)%P+P)%P*inva%P;
}
}
}
那么我们需要两个\(lazytag\),分别表示\(a_{i-1}\)和\(a_{i+1}\)的变化量。
\(pushup\)直接暴力合并,\(pushdown\)也直接计算即可。
\(change\)需要区分\(t\),\(query\)返回\(data[2][0]\)即可。
至此,我们完整的过了一遍主要流程。
~~卡常一直是我的痛,所以一下代码要吸氧才能过~~
#include<iostream>
#include<cstdio>
#include<cstring>
#include<ctime>
#define ls (i<<1)
#define rs (i<<1|1)
#define mid (l+r>>1)
#define _s(i,j,k) tre[i].sum[j][k]
#define _d(i,j,k) tre[i].data[j][k]
#define ll long long
using namespace std;
const int N=3e5+10,P=1e9+7;
struct ma {
ll f[3][3];
int n,m;
ma() {
memset(f,0,sizeof f);n=m=0;
}
} s,t;
int read1(){
int x=0;char ch=getchar();
while(ch<'0' || ch>'9') ch=getchar();
while(ch>='0' && ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x;
}
void write1(ll x){
if(x>9) write1(x/10);
putchar(x%10+'0');return;
}
ma operator *(ma x,ma y) {
ma z=ma();
int n=x.n,m=x.m,p=y.m;
z.n=n,z.m=p;
for(int i=0; i<n; ++i)
for(int j=0; j<p; ++j)
for(int k=0; k<m; ++k) z.f[i][j]=(z.f[i][j]+x.f[i][k]*y.f[k][j]%P)%P;
return z;
}
struct tree {
int l,r,tag[2];
ll sum[3][3],data[3][3];
} tre[N<<2];
int n,q,_l,_r,x,y;
ll inva,a,b;
ll f[N][3];
int A[N];
char ch[10],_c;
ll kp(ll x,int p) {
if(p==0) return 1;
if(p==1) return x%P;
if(p&1) return x*kp(x*x%P,p>>1)%P;
else return kp(x*x%P,p>>1)%P;
}
ma Kp(ma x,int p) {
if(p==1) return x;
if(p&1) return x*Kp(x*x,p>>1);
else return Kp(x*x,p>>1);
}
void initf() {
inva=kp(a,P-2)%P;
s.n=s.m=3;
s.f[0][0]=s.f[0][1]=s.f[2][2]=1;s.f[1][0]=a;s.f[2][0]=b;
t.n=1,t.m=3;
t.f[0][0]=2;t.f[0][1]=t.f[0][2]=1;
for(int i=1; i<=n; ++i) {
if(A[i]==1) {
f[i][0]=-P,f[i][1]=1,f[i][2]=2;
} else if(A[i]==2) {
f[i][0]=1,f[i][1]=2;
f[i][2]=(f[i][1]+a*f[i][0]%P+b)%P;
} else {
ma tmp=t*Kp(s,A[i]-2);
f[i][0]=tmp.f[0][1];
f[i][1]=tmp.f[0][0];
f[i][2]=(f[i][1]+a*f[i][0]%P+b)%P;
}
}
}
void pushup(int i) {
for(int j=0; j<=1; ++j)
for(int k=0; k<=2; ++k) _s(i,j,k)=(_s(ls,j,k)+_s(rs,j,k))%P;
for(int j=0; j<=2; ++j)
for(int k=0; k<=2; ++k) _d(i,j,k)=(_d(ls,j,k)+_d(rs,j,k))%P;
}
void build(int i,int l,int r) {
tre[i].l=l,tre[i].r=r;
if(l==r) {
for(int k=0; k<=2; ++k) _s(i,0,k)=f[l-1][k];
for(int k=0; k<=2; ++k) _s(i,1,k)=f[l+1][k];
for(int j=0; j<=2; ++j)
for(int k=0; k<=2; ++k) _d(i,j,k)=_s(i,0,j)*_s(i,1,k)%P;
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(i);return;
}
void add(int i,int t) {
int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
for(int k=0; k<=1; ++k) _s(i,t,k)=_s(i,t,k+1);
_s(i,t,2)=(_s(i,t,1)+a*_s(i,t,0)%P+b*w%P)%P;
if(t==0) { //a[i-1]
for(int k=0; k<=2; ++k)
for(int j=0; j<=1; ++j) _d(i,j,k)=_d(i,j+1,k);
for(int k=0; k<=2; ++k) _d(i,2,k)=(_d(i,1,k)+_d(i,0,k)*a%P+b*_s(i,1,k)%P)%P;
} else {
for(int j=0; j<=2; ++j)
for(int k=0; k<=1; ++k) _d(i,j,k)=_d(i,j,k+1);
for(int j=0; j<=2; ++j) _d(i,j,2)=(_d(i,j,1)+a*_d(i,j,0)%P+b*_s(i,0,j)%P)%P;
}//a[i+1]
return;
}
void del(int i,int t) {
int l=tre[i].l,r=tre[i].r;ll w=r-l+1;
if(a==0){
for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
_s(i,t,0)=(_s(i,t,1)-b*w%P+P)%P;
if(t==0){//a[i-1]
for(int k=0;k<=2;++k)
for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,1,k)-b*_s(i,1,k)%P+P)%P;
}else{//a[i+1]
for(int j=0;j<=2;++j)
for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,1)-b*_s(i,0,j)%P+P)%P;
}
}else{
for(int k=2;k>=1;--k) _s(i,t,k)=_s(i,t,k-1);
_s(i,t,0)=(_s(i,t,2)-_s(i,t,1)+P-b*w%P+P)%P*inva%P;
if(t==0){//a[i-1]
for(int k=0;k<=2;++k)
for(int j=2;j>=1;--j) _d(i,j,k)=_d(i,j-1,k);
for(int k=0;k<=2;++k) _d(i,0,k)=(_d(i,2,k)-_d(i,1,k)+P-b*_s(i,1,k)%P+P)%P*inva%P;
}else{//a[i+1]
for(int j=0;j<=2;++j)
for(int k=2;k>=1;--k) _d(i,j,k)=_d(i,j,k-1);
for(int j=0;j<=2;++j) _d(i,j,0)=(_d(i,j,2)-_d(i,j,1)+P-b*_s(i,0,j)%P+P)%P*inva%P;
}
}
}
void calc(int i,int t,int c) {
if(!c) return;
if(c>0) for(int k=1; k<=c; ++k) add(i,t);
else for(int k=1; k<=-c; ++k) del(i,t);
}
void pushdown(int i) {
int l=tre[i].l,r=tre[i].r,c=tre[i].tag[0],d=tre[i].tag[1];
if(!c && !d) return;
tre[ls].tag[0]+=c;tre[ls].tag[1]+=d;
tre[rs].tag[0]+=c;tre[rs].tag[1]+=d;
calc(ls,0,c);calc(ls,1,d);
calc(rs,0,c);calc(rs,1,d);
tre[i].tag[0]=tre[i].tag[1]=0;
}
void change(int i,int el,int er,int t,int c) {
int l=tre[i].l,r=tre[i].r;
if(el<=l && r<=er) {
tre[i].tag[t]+=c;
calc(i,t,c);
return;
}
pushdown(i);
if(el<=mid) change(ls,el,er,t,c);
if(er>mid) change(rs,el,er,t,c);
pushup(i);
return;
}
ll query(int i,int el,int er) {
int l=tre[i].l,r=tre[i].r;
if(el<=l && r<=er) return _d(i,2,0);
ll ans=0;
pushdown(i);
if(el<=mid) ans=(ans+query(ls,el,er))%P;
if(er>mid) ans=(ans+query(rs,el,er))%P;
return ans;
}
int main() {
n=read1(),q=read1(),a=(long long)read1(),b=(long long)read1();
for(int i=1; i<=n; ++i) A[i]=read1();
initf();
build(1,1,n);
for(int i=1; i<=q; ++i) {
scanf("%s",ch);_l=read1(),_r=read1();
if(ch[0]=='q') {
if(_l+1<=_r-1) write1(query(1,_l+1,_r-1)),putchar('\n');
else putchar('0'),putchar('\n');
} else if(ch[0]=='p') {
x=_l+1,y=_r+1<n?_r+1:n;change(1,x,y,0,1);
x=_l-1>1?_l-1:1,y=_r-1;change(1,x,y,1,1);
} else { //m
x=_l+1,y=_r+1<n?_r+1:n;change(1,x,y,0,-1);
x=_l-1>1?_l-1:1,y=_r-1;change(1,x,y,1,-1);
}
}
return 0;
}
~~学流程用了一天半,写代码不到两小时写了两百行,所以思路清晰是非常重要的。~~