Sequence

Time limit: 3000 ms    Memory limit: 262144 kB

链接

http://acm.hdu.edu.cn/showproblem.php?pid=6589

题意

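给出一个长度为 $n$ 的数列 $a_1, a_2, \dots, a_n$,然后进行 $m$ 次操作。每次操作给出一个 $c \in \{1, 2, 3\}$,生成新数列 $b$:$b_i = \sum_{k \ge 0,\ i - kc \ge 1} a_{i-kc}$,并用新数列 $b$ 替换掉原来的数列 $a$。所有操作结束后,输出 $\bigoplus_{i=1}^{n} (i \cdot a_i)$,其中 $a_i$ 对 $998244353$ 取模,$\oplus$ 表示按位异或。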

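思路

把数列看成生成函数 $A(x) = \sum_{i} a_i x^{i-1}$。一次类型为 $c$ 的操作等价于 $A(x) \leftarrow A(x) \cdot \frac{1}{1 - x^c}$,因此各操作之间可以交换顺序,只需统计每种类型出现的次数 $k_c$。对类型 $c$ 执行 $k_c$ 次,相当于把下标按模 $c$ 分成 $c$ 个子序列,每个子序列各自乘上 $\frac{1}{(1-x)^{k_c}} = \sum_{i \ge 0} \binom{k_c + i - 1}{i} x^i$。用 NTT 做一次卷积并截取前若干项即可,每种类型的复杂度为 $O(n \log n)$。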

题解

代码

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
using namespace std;

using ll = long long;

constexpr ll mod = 998244353;   // NTT-friendly prime: 119 * 2^23 + 1
constexpr int g = 3;            // generator used to build the NTT roots of unity
constexpr int maxn = 2000005;   // capacity of the factorial tables and work arrays

// Extended Euclid: on return a*x + b*y == gcd(a, b).
// Iterative form. The loop keeps the invariants
//   r0 == a*s0 + b*t0   and   r1 == a*s1 + b*t1,
// so once r1 reaches 0, r0 is the gcd and (s0, t0) are its coefficients.
void exgcd (ll a, ll b, ll &x, ll &y)
{
    ll r0 = a, r1 = b;
    ll s0 = 1, s1 = 0;
    ll t0 = 0, t1 = 1;
    while (r1 != 0) {
        const ll q = r0 / r1;
        const ll r2 = r0 - q * r1;
        const ll s2 = s0 - q * s1;
        const ll t2 = t0 - q * t1;
        r0 = r1;
        r1 = r2;
        s0 = s1;
        s1 = s2;
        t0 = t1;
        t1 = t2;
    }
    x = s0;
    y = t0;
}

// Modular inverse of k modulo the prime `mod`, normalised into [0, mod).
// Uses exgcd: k*coef + mod*other == gcd(k, mod), which is 1 when k is not a
// multiple of mod.
ll inv (ll k)
{
    ll coef, other;
    exgcd(k, mod, coef, other);
    coef %= mod;
    if (coef < 0) {
        coef += mod;
    }
    return coef;
}

// Fast modular exponentiation: a^b mod `mod`.
// A negative exponent is handled as (a^{-1})^{|b|}.
ll qm (ll a, ll b)
{
    if (b < 0) {
        a = inv(a);
        b = -b;
    }
    ll base = a % mod;
    ll acc = 1;
    // Square-and-multiply over the bits of b, least significant first.
    for (; b > 0; b >>= 1, base = base * base % mod) {
        if (b & 1) {
            acc = acc * base % mod;
        }
    }
    return acc;
}

ll rev[maxn];

// Fill rev[0 .. 2^bit) with the bit-reversal permutation of `bit`-bit indices.
// The reversal of i is the reversal of i>>1 shifted right once, with i's low
// bit moved into the top position. Every entry depends only on smaller indices.
void get_rev (ll bit)
{
    const ll size = 1 << bit;
    rev[0] = 0;
    for (ll i = 1; i < size; ++i) {
        const ll high = (i & 1) << (bit - 1);
        rev[i] = (rev[i >> 1] >> 1) | high;
    }
}

// In-place iterative number-theoretic transform of ar[0 .. n) modulo `mod`.
//   ar  : data; overwritten with the transform (values end up in [0, mod)).
//   n   : transform length. Presumably a power of two, with rev[] already filled
//         for this length by get_rev() — TODO confirm at every call site.
//   dft : 1 for the forward transform, -1 for the inverse. The inverse also
//         multiplies every element by n^{-1}.
void ntt (ll *ar, ll n, ll dft)
{
    // Reorder into bit-reversed index order so the butterflies can run bottom-up.
    for (ll i = 0; i < n; ++i) {
        if (i < rev[i]) {
            swap(ar[i], ar[rev[i]]);
        }
    }
    // Merge blocks of size `step` into blocks of size 2*step.
    for (ll step = 1; step < n; step <<= 1) {
        ll wn;
        // (2*step)-th root of unity g^((mod-1)/(2*step)). For dft == -1 the
        // exponent is negative, so qm() returns the inverse root.
        wn = qm(g, dft*(mod-1)/(step*2));
        for (ll j = 0; j < n; j += (step<<1)) {
            ll wnk = 1;  // twiddle factor wn^(k-j) for the current butterfly
            for (ll k = j; k < j+step; ++k) {
                // Butterfly: (x, y) -> (x + w*y, x - w*y), both kept in [0, mod).
                ll x = ar[k]%mod, y = (wnk*ar[k+step])%mod;
                ar[k] = (x+y)%mod;
                ar[k+step] = ((x-y)%mod+mod)%mod;
                wnk = (wnk*wn)%mod;
            }
        }
    }
    // The inverse transform needs a final scaling by 1/n.
    if (dft == -1) {
        ll nI = inv(n);
        for (ll i = 0; i < n; ++i) {
            ar[i] = ar[i]*nI%mod;
        }
    }
}

// Multiply two length-n polynomials modulo `mod` via NTT.
// On return ar[0 .. s) holds the product coefficients, where s is the
// transform length chosen below. br is clobbered: it is zero-padded and
// left holding its forward transform. Both buffers must have room for
// s entries.
void convolution (ll *ar, ll *br, ll n)
{
    // Transform length: the smallest power of two (at least 2) that is
    // >= 2*n - 1, so the linear product does not wrap around.
    ll bit = 1;
    while ((1 << bit) < 2 * n - 1) {
        ++bit;
    }
    const ll s = 1LL << bit;
    get_rev(bit);

    // Zero-pad both inputs beyond their first n coefficients.
    for (ll i = n; i < s; ++i) {
        ar[i] = 0;
        br[i] = 0;
    }

    ntt(ar, s, 1);
    ntt(br, s, 1);
    // Pointwise product in the transformed domain.
    for (ll i = 0; i < s; ++i) {
        ar[i] = ar[i] * br[i] % mod;
    }
    ntt(ar, s, -1);
}

ll jc[maxn];   // jc[i]  = i!          (mod `mod`)
ll ijc[maxn];  // ijc[i] = (i!)^{-1}   (mod `mod`)

// Precompute factorials and inverse factorials for 0 <= i < maxn.
// Only the largest factorial is inverted directly (one exgcd call). The other
// inverses follow from ((i-1)!)^{-1} = (i!)^{-1} * i, which replaces ~maxn
// separate extended-Euclid inversions with a linear pass. The results are
// identical because inverses modulo a prime are unique. Requires maxn - 1 < mod,
// so that no factorial is 0 mod `mod`.
void init ()
{
    jc[0] = 1;
    for (int i = 1; i < maxn; ++i) {
        jc[i] = 1ll*i*jc[i-1]%mod;
    }
    ijc[maxn-1] = inv(jc[maxn-1]);
    for (int i = maxn-1; i > 0; --i) {
        ijc[i-1] = ijc[i]*i%mod;
    }
}

// Binomial coefficient C(n, m) modulo `mod`, read from the tables filled by
// init(). Returns 0 when m < 0 or m > n instead of indexing jc/ijc out of
// range. Assumes n < maxn.
ll C (int n, int m) {
    if (m < 0 || m > n) {
        return 0;
    }
    return jc[n]*ijc[m]%mod*ijc[n-m]%mod;
}

ll tar[maxn];  // scratch: copy of the input sequence, later the product
ll tbr[maxn];  // scratch: coefficients of 1/(1-x)^m

// Apply the prefix-sum operation m times to pr[0 .. n), in place.
// Doing m prefix summations multiplies the generating function by 1/(1-x)^m,
// whose x^i coefficient is C(m+i-1, i). The result is the first n terms of
// the convolution of pr with that series.
void get (ll *pr, int n, int m)
{
    for (int i = 0; i < n; ++i) {
        tbr[i] = C(m + i - 1, i);
    }
    std::copy(pr, pr + n, tar);
    convolution(tar, tbr, n);
    std::copy(tar, tar + n, pr);
}

int n, m;                            // current test case: length and number of operations
ll ar[maxn];                         // working sequence a[0 .. n)
ll ar1[maxn], ar2[maxn], ar3[maxn];  // scratch buffers for residue-class subsequences
int br[4];                           // br[c] = count of operations of type c (c = 1..3)

int main ()
{
    int t;
    init();
    scanf("%d", &t);
    while (t--) {
        br[1] = br[2] = br[3] = 0;
        scanf("%d %d", &n, &m);
        for (int i = 0; i < n; ++i) {
            scanf("%lld", &ar[i]);
        }
        for (int i = 1; i <= m; ++i) {
            int tmp;
            scanf("%d", &tmp);
            ++br[tmp];
        }
        if (br[1] != 0) {
            get(ar, n, br[1]);
        }
        if (br[2] != 0) {
            int ln1 = 0, ln2 = 0;
            for (int i = 0; i < n; ++i) {
                if (i % 2 == 0) {
                    ar1[ln1++] = ar[i];
                }
                if (i % 2 == 1) {
                    ar2[ln2++] = ar[i];
                }
            }
            get(ar1, ln1, br[2]);
            get(ar2, ln2, br[2]);
            for (int i = 0; i < n; ++i) {
                if (i % 2 == 0) {
                    ar[i] = ar1[i/2];
                }
                if (i % 2 == 1) {
                    ar[i] = ar2[i/2];
                }
            }
        }
        if (br[3] != 0) {
            int ln1 = 0, ln2 = 0, ln3 = 0;
            for (int i = 0; i < n; ++i) {
                if (i % 3 == 0) {
                    ar1[ln1++] = ar[i];
                }
                if (i % 3 == 1) {
                    ar2[ln2++] = ar[i];
                }
                if (i % 3 == 2) {
                    ar3[ln3++] = ar[i];
                }
            }
            get(ar1, ln1, br[3]);
            get(ar2, ln2, br[3]);
            get(ar3, ln3, br[3]);
            for (int i = 0; i < n; ++i) {
                if (i % 3 == 0) {
                    ar[i] = ar1[i/3];
                }
                if (i % 3 == 1) {
                    ar[i] = ar2[i/3];
                }
                if (i % 3 == 2) {
                    ar[i] = ar3[i/3];
                }
            }
        }
        ll res = 0;
        for (int i = 0; i < n; ++i) {
            res ^= ((i+1)*ar[i]);
        }
        cout << res << '\n';
    }
    return 0;
}

发表评论