Removing Stones

题目链接

题意

有$n$堆石子,每次可以拿走不同的两堆石子各一个,如果石子总数是奇数,可以开始先从最少的堆那一个,这样下去如果最后没拿完就输了。
现在问有多少个区间满足,在这些区间的石子中玩游戏,最后可以赢。

思路

考虑怎么才会赢,也就是对于任一一堆石子,满足它的数量小于等于其他石子数量之和。
其实只要最大的那一堆满足就可以了。
那么很容易想到枚举哪一堆石子是最大的,可以得到它的起作用区间$(l,r)$,设第$k$堆石子数量最多。
那么我们分治$(l, k-1)$与$(k+1,r)$,现在考虑一下跨过$k$的长度大于1的区间。

这里有一个策略:去找$r-k$与$k-l$较小的那边去枚举,另一侧去二分,这样可以保证复杂度为两个

这里还有一个更优秀的做法

(没看懂 待补)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#define N 300010
#define INF 0x3f3f3f3f
#define eps 1e-10
// #define pi 3.141592653589793
#define mod 97
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%dd",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>
using namespace __gnu_pbds;

int a[N],lg[N];
LL s[N],ans;
int f[N][19];

inline int idx(int l,int r){
int t=lg[r-l];
if (a[f[l][t]]>a[f[r-(1<<t)+1][t]])
return f[l][t];else
return f[r-(1<<t)+1][t];
}

void work(int l,int r){
if (l>=r) return;
if (l+1==r){
if (a[l]==a[r]) ans++;
return;
}
int k=idx(l,r),t;LL M=a[k],tm=0,la;
if (k-l>r-k){
for(int i=k;i<=r;i++){
tm+=a[i]; la=M*2-tm;
if (s[i]-s[l-1]<M*2) continue;
if(la<=0) t=k;else {
t=lb(s+l,s+k+1,s[k-1]-la)-s;
if (la==s[k-1]-s[t]) t++;
}
ans+=t-l+1;
}
}else{
for(int i=k;i>=l;i--){
tm+=a[i]; la=M*2-tm;
if (s[r]-s[i-1]<M*2) continue;
if (la<=0) t=k;else
t=lb(s+k,s+r+1,s[k]+la)-s;
ans+=r-t+1;
}
}
work(l,k-1);
work(k+1,r);
}

int main(){
for(int i=1;(1<<i)<N;i++) lg[1<<i]=1;
for(int i=1;i<N;i++) lg[i]+=lg[i-1];
int T;
sc(T);
while(T--){
int n;
sc(n);
for(int i=1;i<=n;i++) sc(a[i]),s[i]=s[i-1]+a[i],f[i][0]=i;

for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++) if (i+(1<<j)-1<=n)
if (a[f[i][j-1]]>a[f[i+(1<<j-1)][j-1]])
f[i][j]=f[i][j-1];else
f[i][j]=f[i+(1<<j-1)][j-1];

ans=0;
work(1,n);
printf("%lld\n",ans);
}
}

0%