需要一点数学基础的题目叭,我因为比较菜,观察能力不够,所以卡了半个小时叭.
我们化一化原式:
$$(x^2+y)^2\equiv (x^2-y)^2\pmod p$$
$$x^4+2x^2y+y^2\equiv x^4-2x^2y+y^2\pmod p$$
$$2x^2y\equiv-2x^2y\pmod p$$
$$4x^2y\equiv 1\pmod p$$
唉唉唉!$y$不就是$4x^2$在模$p$意义下的逆元吗?
题目里又保证$p$一定是质数,所以逆元唯一.
虽然原序列中的元素保证两两不同,但极可能会有属于同一剩余类的元素.
而逆元其实也是一个剩余类,所以我们要取原序列的所有剩余类进行统计.
这样就做完了叭?
不,没有.
你还要考虑不存在逆元的情况.
虽然模数是质数,对于一般的数字一定存在逆元,但有一个剩余类是例外:$0$.
显然,$0$是不应该有逆元的.
这样才是真正做完了这题.
$Code:$1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
using std::queue ;
using std::set ;
using std::pair ;
using std::max ;
using std::min ;
using std::priority_queue ;
using std::vector ;
using std::swap ;
using std::sort ;
using std::unique ;
using std::greater ;
template < class T >
inline T read () {
T x = 0 , f = 1 ; char ch = getchar () ;
while ( ch < '0' || ch > '9' ) {
if ( ch == '-' ) f = - 1 ;
ch = getchar () ;
}
while ( ch >= '0' && ch <= '9' ) {
x = ( x << 3 ) + ( x << 1 ) + ( ch - 48 ) ;
ch = getchar () ;
}
return f * x ;
}
const int N = 1e5 + 100 ;
std::map < int , int > mk ;
int n , ans , mod , v[N] ;
inline int quick (int a , int p) {
int res = 1 ;
while ( p ) {
if ( p & 1 ) res = ( res * a ) % mod ;
a = a * a % mod ; p >>= 1 ;
}
return res % mod ;
}
inline int check (int ll , int rr , int x) {
int l = ll , r = rr ;
while ( l <= r ) {
int mid = ( l + r ) >> 1 ;
if ( v[mid] % mod == x ) return mid ;
if ( v[mid] % mod < x ) l = mid + 1 ;
if ( v[mid] % mod > x ) r = mid - 1 ;
}
return - 1 ;
}
signed main (int argc , char * argv[]) {
n = rint () ; mod = rint () ;
rep ( i , 1 , n ) v[i] = rint () ;
rep ( i , 1 , n ) { v[i] %= mod ; ++ mk[v[i]] ; }
sort ( v + 1 , v + n + 1 ) ;
for (int i = 1 ; i <= n ; ++ i) {
if ( ! v[i] ) continue ;
int tmp = ( v[i] * v[i] << 2 ) % mod ;
int inv = quick ( tmp , mod - 2 ) % mod ;
int pos = check ( 1 , n , inv ) ;
if ( pos == - 1 || pos == i ) continue ;
ans += mk[v[pos]] ;
}
printf ("%lld\n" , ans ) ;
return 0 ;
}