Fork me on GitHub

ZROI#987

ZROI#987

差分+简单数学即可.
首先有个性质:
两条链相交等价于其中一条链的$LCA$在另一条链上.
于是我们就对每一条链的$LCA$都加$1$.
最后查询每一条链的区间和即可.树剖实现.
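但这样会算重: 对相交的两条链 $a,b$, 若 $a$ 的 $LCA$ 在 $b$ 上, 查询 $b$ 时会计一次; 若 $b$ 的 $LCA$ 也在 $a$ 上, 查询 $a$ 时又会计一次. 两者同时成立当且仅当 $a,b$ 的 $LCA$ 相同 (否则深度较小的那个 $LCA$ 不可能落在另一条链上). 另外查询每条链时也会把它自己的 $LCA$ 算进去, 要减 $1$.
因此答案为 $\sum_i (\text{链 } i \text{ 上的和} - 1) - \sum_v \binom{c_v}{2}$, 其中 $c_v$ 为以 $v$ 为 $LCA$ 的链数 (同一 $LCA$ 的链对被算了两次, 减掉一次即可).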
$Code:$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <queue>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
// Shorthand macros used by the rest of the solution.
// MEM(x, y): memset the whole array x to byte value y.
#define MEM(x,y) memset ( x , y , sizeof ( x ) )
// rep / per: inclusive ascending / descending loops over [a, b].
#define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
#define per(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
#define pii pair < int , int >
#define X first
#define Y second
// rint(): read one integer from stdin. Because `int` is itself a macro at the use sites,
// this expands to read<long long>.
#define rint read<int>
// From here on every `int` is really `long long` (which is why main is declared `signed`
// and the answer is printed with "%lld").
#define int long long
#define pb push_back
// Segment-tree index helpers: left child, right child and midpoint of [l, r].
// They expect variables named `rt`, `l` and `r` to be in scope where they are used.
#define ls ( rt << 1 )
#define rs ( rt << 1 | 1 )
#define mid ( ( l + r ) >> 1 )

using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;

// Reads one (optionally negative) integer of type T from stdin.
// Any '-' seen before the first digit makes the result negative.
// If EOF is reached before a digit appears, returns 0 instead of looping forever.
template < class T >
inline T read () {
    T x = 0 , f = 1 ;
    // Keep getchar()'s result in an int so EOF stays distinct from every valid byte,
    // even on platforms where plain char is unsigned.
    int ch = getchar () ;
    while ( ch != EOF && ( ch < '0' || ch > '9' ) ) {
        if ( ch == '-' ) f = - 1 ;
        ch = getchar () ;
    }
    while ( ch >= '0' && ch <= '9' ) {
        x = x * 10 + ( ch - '0' ) ;
        ch = getchar () ;
    }
    return f * x ;
}

// Capacity for vertices and paths; inputs larger than this are not checked.
const int N = 1e6 + 100 ;

vector < int > G[N] ;   // adjacency lists of the tree
int f[N] ;              // f[v]: parent of v (main roots the tree at 1 with parent 0)
int deep[N] ;           // deep[v]: depth of v (main gives the root depth 1)
int siz[N] ;            // siz[v]: number of vertices in v's subtree
int son[N] ;            // son[v]: heavy child of v (stays 0 for a leaf)
int top[N] ;            // top[v]: head of the heavy chain containing v
int idx[N] ;            // idx[v]: position of v in the heavy-child-first DFS order
int cnt ;               // number of idx positions handed out so far
int n , m ;             // vertex count and path count
int p[N][2] ;           // p[i][0], p[i][1]: endpoints of path i
int ans ;               // accumulated answer

// One segment-tree node. `data` is the sum over positions [left, right]; `tag` is a
// per-position add that has not yet been pushed down to the two children.
struct seg {
    int left , right ;  // closed interval of positions covered by this node
    int data ;          // range sum
    int tag ;           // pending lazy add owed to the children
    // Number of positions covered by this node.
    inline int size () { return right + 1 - left ; }
} t[N<<2] ;             // 4 * N nodes for a tree over at most N positions

// First heavy-light pass over the subtree rooted at `cur`, whose parent is `anc` and depth `dep`.
// Fills f (parent), deep (depth), siz (subtree size) and son (heavy child: the first neighbour in
// G order with the strictly largest subtree; untouched for leaves).
// Written iteratively: the recursive form needs O(n) call-stack depth, which can overflow on a
// path-shaped tree with ~1e6 vertices.
inline void dfs (int cur , int anc , int dep) {
    f[cur] = anc ; deep[cur] = dep ;
    // Pass 1: pre-order walk with an explicit stack, recording parent, depth and siz = 1.
    std::vector < int > order , stk ;
    stk.push_back ( cur ) ;
    while ( ! stk.empty () ) {
        int u = stk.back () ; stk.pop_back () ;
        order.push_back ( u ) ;
        siz[u] = 1 ;
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            f[k] = u ; deep[k] = deep[u] + 1 ;
            stk.push_back ( k ) ;
        }
    }
    // Pass 2: reverse pre-order reaches every child before its parent, so children's sizes are
    // final when the parent scans its neighbours in G order, exactly like the recursive version.
    for (auto it = order.rbegin () ; it != order.rend () ; ++ it) {
        int u = *it , maxson = - 1 ;
        for (int k : G[u]) {
            if ( k == f[u] ) continue ;
            siz[u] += siz[k] ;
            if ( siz[k] > maxson ) maxson = siz[k] , son[u] = k ;
        }
    }
}

// Second heavy-light pass: assigns idx (DFS position, heavy child first, so each heavy chain
// occupies a contiguous idx range) and top (head of each vertex's chain), starting at `cur`
// whose chain head is `topf`. Requires f and son from dfs().
// Visit order matches the recursive version (vertex, then heavy child's subtree, then light
// children in G order). Written iteratively so long heavy chains cannot overflow the call stack.
inline void _dfs (int cur , int topf) {
    top[cur] = topf ;
    std::vector < int > stk ;
    stk.push_back ( cur ) ;
    while ( ! stk.empty () ) {
        int u = stk.back () ; stk.pop_back () ;
        idx[u] = ++ cnt ;
        if ( ! son[u] ) continue ;
        // Light children start their own chains; push them in reverse so they pop in G order.
        for (auto it = G[u].rbegin () ; it != G[u].rend () ; ++ it) {
            int k = *it ;
            if ( k == son[u] || k == f[u] ) continue ;
            top[k] = k ;
            stk.push_back ( k ) ;
        }
        // Heavy child continues u's chain and is pushed last so it is visited right after u.
        top[son[u]] = top[u] ;
        stk.push_back ( son[u] ) ;
    }
}

// Recomputes node rt's range sum from its two children.
inline void pushup (int rt) {
    const int left_sum = t[rt << 1].data ;
    const int right_sum = t[rt << 1 | 1].data ;
    t[rt].data = left_sum + right_sum ;
}

inline void build (int rt , int l , int r) {
t[rt].left = l ; t[rt].right = r ; t[rt].tag = 0 ;
if ( l == r ) { t[rt].data = 0 ; return ; }
build ( ls , l , mid ) ; build ( rs , mid + 1 , r ) ;
pushup ( rt ) ; return ;
}

inline void pushdown (int rt) {
t[ls].tag += t[rt].tag ; t[rs].tag += t[rt].tag ;
t[ls].data += t[ls].size () * t[rt].tag ;
t[rs].data += t[rs].size () * t[rt].tag ;
t[rt].tag = 0 ; return ;
}

// Adds `val` to every position in [ll, rr], which must lie inside node rt's interval.
// On a fully covered node the sum grows by val per covered position, consistent with how
// pushdown() treats `tag`. The previous `data += val` was only correct for single-position
// updates, which are the only kind main performs.
inline void update (int rt , int ll , int rr , int val) {
    int l = t[rt].left , r = t[rt].right ;
    if ( l == ll && r == rr ) {
        t[rt].tag += val ;                    // owed to the children, per position
        t[rt].data += t[rt].size () * val ;   // this node's sum is updated immediately
        return ;
    }
    if ( t[rt].tag ) pushdown ( rt ) ;
    if ( rr <= mid ) update ( ls , ll , rr , val ) ;
    else if ( ll > mid ) update ( rs , ll , rr , val ) ;
    else { update ( ls , ll , mid , val ) ; update ( rs , mid + 1 , rr , val ) ; }
    pushup ( rt ) ;
}

inline int query (int rt , int ll , int rr) {
int l = t[rt].left , r = t[rt].right ;
if ( ll == l && r == rr ) return t[rt].data ;
if ( t[rt].tag ) pushdown ( rt ) ;
if ( rr <= mid ) return query ( ls , ll , rr ) ;
else if ( ll > mid ) return query ( rs , ll , rr ) ;
else return query ( ls , ll , mid ) + query ( rs , mid + 1 , rr ) ;
}

// Sum of the segment-tree values of all vertices on the tree path between x and y
// (both endpoints included), walking heavy chains via top/idx.
inline int qrange (int x , int y) {
    int total = 0 ;
    for ( ; top[x] != top[y] ; ) {
        // Always climb from the endpoint whose chain head is deeper.
        if ( deep[top[y]] > deep[top[x]] ) swap ( x , y ) ;
        total += query ( 1 , idx[top[x]] , idx[x] ) ;
        x = f[top[x]] ;
    }
    // Same chain now: the shallower vertex has the smaller idx.
    if ( deep[y] < deep[x] ) swap ( x , y ) ;
    total += query ( 1 , idx[x] , idx[y] ) ;
    return total ;
}

// Lowest common ancestor of x and y. Repeatedly moves the endpoint whose chain head is deeper
// to that head's parent until both lie on one chain; the shallower of the two is the answer.
inline int LCA (int x , int y) {
    while ( top[x] != top[y] ) {
        if ( deep[top[x]] < deep[top[y]] ) y = f[top[y]] ;
        else x = f[top[x]] ;
    }
    return deep[y] <= deep[x] ? y : x ;
}

signed main (int argc , char * argv[]) {
n = rint () ; m = rint () ;
rep ( i , 2 , n ) {
int u = rint () , v = rint () ;
G[u].pb ( v ) ; G[v].pb ( u ) ;
}
dfs ( 1 , 0 , 1 ) ; _dfs ( 1 , 1 ) ; build ( 1 , 1 , cnt ) ;
rep ( i , 1 , m ) {
p[i][0] = rint () ; p[i][1] = rint () ;
int t = LCA ( p[i][0] , p[i][1] ) ;
update ( 1 , idx[t] , idx[t] , 1 ) ;
}
rep ( i , 1 , m ) ans += ( qrange ( p[i][0] , p[i][1] ) - 1 ) ;
rep ( i , 1 , n ) {
int tmp = query ( 1 , idx[i] , idx[i] ) ;
ans -= tmp * ( tmp - 1 ) / 2 ;
}
printf ("%lld\n" , ans ) ;
return 0 ;
}