斜率优化学习笔记

顺便学了个李超树

Posted by ethan-zhou on June 7, 2021

问题形式

用来优化一类 dp,其转移方程满足:

\[dp_i=f_i(\max_{j<i}(k_j\times x_i+b_j))\]

(式子中取 min 也可,与 max 类似,不再赘述)

其中 $k_j,b_j$ 是只与 j 相关的数,$x_i$ 是只与 i 相关的数,$f_i$ 是只与 i 有关的函数。

斜率优化就是优化的求 $\max_{j<i}(k_j\times x_i+b_j)$ 的过程。

两种理解

截距法

将每个 j 对应到平面上一点 $(k_j,b_j)$,然后用一个斜率为 $-x_i$ 的直线从上往下移动,直到碰到一个点,此时 y 轴截距即为答案。

函数法

将 $g_j(x)=k_j\times x+b_j$ 看成只与 j 有关的的一次函数,原式等价于在 $x=x_i$ 处所有函数的极值。

感觉这个方法比较好理解,后文都以此方法为例。

如何维护

因为是将所有一次函数取 max,所以所有可能成为答案的直线必定是一个斜率递增下凸壳,在不同情况下,可采用不同的方法维护凸壳上的函数。

$k_i,x_i$ 都随 i 递增——单调队列

插入直线时,不断弹出队尾,直到队尾的直线在凸壳上没有被新直线完全覆盖,再将新队列加入队尾。

查询时,不断弹出队首,直到队列前两条直线的交点大于 $x_i$,此时队首的函数即为答案。

$k_i$ 随 i 递增,$x_i$ 不随 i 递增——单调队列+二分

插入方法同上,查询时不能直接弹队首了,应该在队列上二分,找到凸壳上包含 $x_i$ 的一段直线。

$k_i,x_i$ 都不随 i 递增——李超树

可以用李超树维护这个凸壳,后文有介绍,(听说也可以用平衡树,CDQ 分治维护,但不太好写)。

简单快捷李超树

李超树是一个支持以 $O(\log n)$ 复杂度插入一条直线(如果是插入线段,则复杂度为 $O(\log^2 n)$),并以同样复杂度查询某一个位置的最大值的线段树。

插入操作

李超树的每个节点上维护的是:在当前区间的中点,函数值最大的函数(下称最优函数),可以如下更新:

  • 考虑新插入的函数,与区间的原最优函数:
  • 将两者中较的存到这个节点上
  • 用两者中较的更新子区间:
    • 如果较劣的函数比较优的函数斜率,则用其递归更新右子区间(如下面图 1 所示,这条线段仍然有可能在右子区间里成为最优函数)
    • 否则用其递归更新左子区间(同理)

图 1

查询操作

求 x 处最大的函数值,只需将线段树上 x 所在区间,以及它祖先上所有区间的最优函数,在x处的函数值取 max 即可。

题目

AcWing 301. 任务安排2

费用提前计算,记 $t_i,c_i$ 的前缀和为 $st_i,sc_i$,列出转移方程:

\[dp_i=\min_{j<i}(\underbrace{-sc_j}_{k_j}\times \underbrace{st_i}_{x_i}+\underbrace{dp_j+s\times(sc_n-sc_j)}_{b_j})+st_i*sc_i\]

其中斜率与横坐标 $-sc_j,st_i$ 都单调,只需单调队列维护即可 但我是拿二分的代码过的

AcWing 302. 任务安排3

与上一题题意相同,区别在 $st_i$ 不单调,只需在单调队列上二分即可。

代码:

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pi;
const ll INF=1e18;
const ll MXN=3e5+5;

inline long double inter(const pi &x,const pi &y){return x.fi==y.fi?(x.se>y.se?-INF:INF):(long double)(y.se-x.se)/(x.fi-y.fi);}

ll n,s,sc[MXN],st[MXN],dp[MXN];
pi q[MXN];ll ql=1,qr;
int main(){
	scanf("%lld%lld",&n,&s);
	for(int i=1;i<=n;i++){
		scanf("%lld%lld",st+i,sc+i);
		st[i]+=st[i-1],sc[i]+=sc[i-1];
	}
	q[++qr]=mp(0,s*sc[n]);
	for(int i=1;i<=n;i++){
		ll l=ql,r=qr;
		while(l<r){
			ll mid=(l+r)>>1;
			if(inter(q[mid],q[mid+1])>=st[i])r=mid;
			else l=mid+1;
		}
		dp[i]=q[l].fi*st[i]+q[l].se+st[i]*sc[i];
		pi curseg=mp(-sc[i],dp[i]+s*(sc[n]-sc[i]));
		while(ql<qr && inter(q[qr-1],q[qr])>=inter(q[qr],curseg))qr--;
		q[++qr]=curseg;
	}
	printf("%lld",dp[n]);

	return 0;
}

P3628 [APIO2010]特别行动队

直接推式子,方程为:

\[dp_i=(a s_i^2+b s_i+c) + \max_{j<i}(\underbrace{-2a s_j}_{k_j}\times\underbrace{s_i}_{x_i} +\underbrace{dp_j+a s_j^2-b s_j}_{b_j})\]

其中斜率与 x 坐标 $-2a s_i,s_i$ 都单调变化,所以单调队列维护即可。

代码:

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pi;

const long double INF=1e18;
const long double EPS=1e-9;
const ll MXN=1e6+5;
ll n,a,b,c,ql=1,qr,dp[MXN];
pi q[MXN];

long double intersec(pi x,pi y){
	if(x.fi==y.fi)return y.se>x.se?-INF:INF;
	return (long double)(y.se-x.se)/(x.fi-y.fi);
}
int main(){
	scanf("%lld%lld%lld%lld",&n,&a,&b,&c);
	q[++qr]=mp(0,0);
	for(ll i=1,tmp,s=0;i<=n;i++){
		scanf("%lld",&tmp);s+=tmp;
		while(ql<qr && intersec(q[ql],q[ql+1])+EPS<s)ql++;
		dp[i]=s*q[ql].fi+q[ql].se+(a*s*s+b*s+c);
		pi cseg=mp(-2*a*s,dp[i]+a*s*s-b*s);
		while(ql<qr && intersec(q[qr-1],q[qr])>intersec(q[qr-1],cseg)+EPS)qr--;
		q[++qr]=cseg;
	}
	printf("%lld",dp[n]);

	return 0;
}

P3648 [APIO2014]序列分割

发现结果与切割顺序无关,不难推出 dp 式,其中斜率与 x 坐标单调增。

注意特判两条线平行的情况,返回正负无穷,至于具体返回哪个,因题而异。

代码:

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pi;

const long double INF=1e18;
const ll MXN=1e5+5;
const ll MXK=205;

ll ql,qr;pi q[MXN];
long double inter(pi x,pi y){
	if(x.fi==y.fi)return y.se>x.se?-INF:-INF;
	return (long double)(y.se-x.se)/(x.fi-y.fi);
}
ll n,k,arr[MXN];
ll dp[MXK][MXN];
int main(){

	scanf("%lld%lld",&n,&k);
	for(int i=1;i<=n;i++){
		scanf("%lld",arr+i);
		arr[i]+=arr[i-1];
	}
	for(int i=1;i<=k;i++){
		q[qr=ql=1]=mp(0,0);
		for(int j=1;j<=n;j++){
			while(ql<qr && arr[j]>=inter(q[ql],q[ql+1]))ql++;
			dp[i][j]=arr[j]*q[ql].fi+q[ql].se;
			pi cseg=mp(arr[j],dp[i-1][j]-arr[j]*arr[j]);
			while(ql<qr && inter(q[qr-1],q[qr])>=inter(q[qr-1],cseg))qr--;
			q[++qr]=cseg;
		}
	}
	printf("%lld\n",dp[k][n]);
	for(int i=k,cur=n;i;i--){
		for(int j=1;j<cur;j++)
			if(dp[i-1][j]+arr[j]*(arr[cur]-arr[j])==dp[i][cur]){
				cur=j;
				break;
			}
		printf("%lld\n",cur);
	}
	return 0;

}

P5017 [NOIP2018 普及组] 摆渡车

dp 转移不难变成斜率优化的形式,可以线性解决。

还可以进一步利用两人到达时间相差超过 $2m$ 可以当成 $2m$ 的性质,缩短时间轴。

复杂度 $O(\min(nm,T))$

代码:

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pi;
const ll MXN=505;
const ll MXT=1.1e5;
const ll INF=1e18;
ll n,m,ans=INF;
ll arr[MXN],s[MXT],si[MXT],dp[MXT];
ll ql=1,qr;
pi q[MXT];
inline long double inter(const pi &a,const pi &b){return a.fi==b.fi?(a.se>b.se?-INF:INF):(long double)(a.se-b.se)/(b.fi-a.fi);}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;i++)scanf("%lld",arr+i);
	sort(arr+1,arr+1+n);
	for(ll i=n;i;i--)arr[i]=min(m<<1,arr[i]-arr[i-1]);
	for(ll i=1;i<=n;i++){
		arr[i]+=arr[i-1];
		s[arr[i]]++;
		si[arr[i]]+=arr[i];
	}

	for(ll i=1;i<=arr[n]+m;i++){
		s[i]+=s[i-1],si[i]+=si[i-1];
		if(i>=m){
			pi cseg=mp(-s[i-m],dp[i-m]+si[i-m]);
			while(ql<qr && inter(q[qr-1],q[qr])>=inter(q[qr],cseg))qr--;
			q[++qr]=cseg;
			while(ql<qr && i>=inter(q[ql],q[ql+1]))ql++;
			dp[i]=q[ql].fi*i+q[ql].se;
		}
		dp[i]+=s[i]*i-si[i];
		if(i>=arr[n])ans=min(ans,dp[i]);
	}
	printf("%lld",ans);
	return 0;
}

P4027 [NOI2007] 货币兑换 (李超树)

发现最优情况下每一次买入必定把所有钱都用了,每次卖出都必定把所有金券卖了。

于是不难推出转移方程为($dp_i$ 表示第 i 天,卖币之后,买币之前,最多拥有多少钱):

\[dp_i=max(dp_{i-1},\max(\underbrace{(B_i/A_i)}_{x_i}\times \underbrace{A_i dp_j/(A_j Rate_j+B_j)}_{k_j}+\underbrace{Rate_j A_i dp_j/(A_j Rate_j+B_j)}_{b_j}))\]

发现斜率与横坐标都不一定单调,只能用李超树维护。

又因为用李超树查询的横坐标是浮点数,不太好搞,可以提前将会查询到的 x 坐标离散化。

代码(甚短):

#include<bits/stdc++.h>
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef pair<double,double> pr;
const ll MXN=1e5+5;
const ll INF=1e18;
ll n;
double arr[MXN],brr[MXN],rate[MXN],dp[MXN],pool[MXN];


struct node{ll ls,rs;pr line;}t[MXN<<2];ll rt,nodec;
inline double f(pr line,ll x){return line.fi*pool[x]+line.se;}
inline double f(pr line,double x){return line.fi*x+line.se;}
inline ll nnode(){return t[++nodec].line=mp(0,-INF),nodec;}
inline void mod(ll &p,ll l,ll r,pr ml){
	ll mid=(l+r)>>1;
	if(!p)p=nnode();
	if(f(ml,mid)>f(t[p].line,mid))swap(ml,t[p].line);
	if(l==r)return;
	ml.fi<t[p].line.fi?mod(t[p].ls,l,mid,ml):mod(t[p].rs,mid+1,r,ml);
}
inline double que(ll p,ll l,ll r,double qx){
	if(!p || l==r)return f(t[p].line,qx);
	ll mid=(l+r)>>1;
	return max(f(t[p].line,qx),qx>pool[mid]?que(t[p].rs,mid+1,r,qx):que(t[p].ls,l,mid,qx));
}
int main(){
	scanf("%lld%lf",&n,&dp[0]);
	for(int i=1;i<=n;i++){
		scanf("%lf%lf%lf",arr+i,brr+i,rate+i);
		pool[i]=brr[i]/arr[i];
	}
	sort(pool+1,pool+1+n);
	for(int i=1;i<=n;i++){
		dp[i]=max(dp[i-1],arr[i]*que(rt,1,n,brr[i]/arr[i]));
		double k=dp[i]/(arr[i]*rate[i]+brr[i]);
		mod(rt,1,n,mp(k,k*rate[i]));
	}
	printf("%.3f",dp[n]);
	
	return 0;
}

作者:@ethan-enhe
本文为作者原创,转载请注明出处:本文链接