Fu_L's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Fu-L/cp-library

:heavy_check_mark: AhoCorasick
(src/string/aho_corasick.hpp)

AhoCorasick

複数のパターン文字列に対する文字列マッチングを行います.

コンストラクタ

AhoCorasick<size_t X = 26, char margin = 'a', bool heavy = true> aho()

計算量

insertなど

void insert(string s, int x)

Trie を継承しているため Trie にある関数が使えます. ただし `move` と `count` は AhoCorasick 側で再定義されており, Trie のものとは引数や意味が異なるので注意してください.

aho にパターン文字列 $s$ を識別子 $x$ として追加します.

計算量

$s$ の長さを $n$ として,

build

void build()

追加されたパターン文字列集合をもとにオートマトンを構築します.

制約

`insert` をすべて終えた後に一度だけ呼んでください.

計算量

match

unordered_map<int, long long> match(string s, int pos = 0)

状態 pos から始めて,テキスト文字列 $s$ が各パターン文字列とマッチした回数を返します.
`heavy` が `true` のときは識別子ごとのマッチ回数を `unordered_map` で返しますが, `false` のときは全パターンのマッチ回数の合計を `long long` で返すので注意してください.

制約

`build` を呼んだ後に使ってください. $s$ の各文字は `margin` 以上 `margin + X` 未満である必要があります.

計算量

$s$ の長さを $n$ として,

move

(1) pair<long long, int> move(char c, int pos = 0)
(2) pair<long long, int> move(string s, int pos = 0)

状態 pos から始めて,次に文字 $c$ または文字列 $s$ が与えられたときに新たにパターン文字列にマッチした個数と次状態を返します.

制約

`build` を呼んだ後に使ってください. 各文字は `margin` 以上 `margin + X` 未満である必要があります.

計算量

Depends on

Verified with

Code

#pragma once
#include "../template/template.hpp"
#include "./trie.hpp"
// Aho-Corasick automaton for matching many patterns over the alphabet
// [margin, margin + X).
// It reuses Trie<X + 1, margin>: slots [0, X) of each node are character
// transitions (completed into a full goto function by build()) and slot X
// holds the failure link.
// heavy == true : every state also keeps the sorted ids of all patterns that
//                 end there (directly or through failure links), so match()
//                 can report per-id counts.
// heavy == false: only the number of such patterns per state (cnt) is kept.
template <size_t X = 26, char margin = 'a', bool heavy = true>
struct AhoCorasick : Trie<X + 1, margin> {
    using TRIE = Trie<X + 1, margin>;
    using TRIE::next;
    using TRIE::st;
    using TRIE::TRIE;
    // cnt[v] = number of inserted (pattern, id) entries that are suffixes of
    // the string leading to state v. Filled by build(); empty before it.
    vector<int> cnt;

    // Adds pattern s under identifier x; must precede build().
    // Unlike the base Trie<X + 1>, asserts that every character is in
    // [margin, margin + X): margin + X would otherwise land in the
    // failure-link slot and be overwritten by build(), losing the pattern.
    void insert(const string& s, const int x) {
        assert(cnt.empty() && "insert() after build() is not supported");
        assert(all_of(s.begin(), s.end(), in_alphabet));
        TRIE::insert(s, x);
    }

    // Completes the goto function, sets failure links and accumulates cnt
    // (and, if heavy, the id lists) along failure links in BFS order.
    // Call exactly once, after all insert()s: a second call would treat the
    // root's redirected edges as children, enqueue the root and never stop.
    void build() {
        assert(cnt.empty() && "build() must be called exactly once");
        const int n = (int)st.size();
        cnt.resize(n);
        for(int i = 0; i < n; ++i) {
            // merge_ids relies on sorted id lists.
            if constexpr(heavy) sort(st[i].idxs.begin(), st[i].idxs.end());
            cnt[i] = (int)st[i].idxs.size();
        }
        queue<int> que;
        for(int i = 0; i < (int)X; ++i) {
            const int child = next(0, i);
            if(child < 0) {
                // Unknown first character: stay at the root.
                next(0, i) = 0;
                continue;
            }
            next(child, X) = 0;
            // cnt[child] will also receive cnt[0] (ids of an empty pattern);
            // keep the id list consistent with that.
            if constexpr(heavy) merge_ids(child, 0);
            que.emplace(child);
        }
        // BFS order: a failure target is strictly shallower than its state, so
        // cnt[fail] is final when v is processed, and the id list of
        // next(fail, i) was merged (when its parent was processed) before
        // child's list is merged here.
        while(!que.empty()) {
            const int v = que.front();
            que.pop();
            const int fail = next(v, X);
            cnt[v] += cnt[fail];
            for(int i = 0; i < (int)X; ++i) {
                const int child = next(v, i);
                if(child < 0) {
                    // Missing edge: reuse the failure state's transition.
                    next(v, i) = next(fail, i);
                    continue;
                }
                next(child, X) = next(fail, i);
                if constexpr(heavy) merge_ids(child, next(fail, i));
                que.emplace(child);
            }
        }
    }

    // Feeds text s starting from state pos and counts pattern occurrences
    // (one per end position per pattern).
    // heavy : returns id -> number of occurrences.
    // !heavy: returns the total number of occurrences.
    // Requires build(); every character of s must be in [margin, margin + X).
    conditional_t<heavy, unordered_map<int, long long>, long long> match(const string& s, int pos = 0) {
        assert(!cnt.empty() && "call build() first");
        // Visits per state, so each state's id list is scanned only once.
        unordered_map<int, int> pos_cnt;
        for(const auto& c : s) {
            pos = next(pos, slot(c));
            ++pos_cnt[pos];
        }
        conditional_t<heavy, unordered_map<int, long long>, long long> res{};
        for(const auto& [key, val] : pos_cnt) {
            if constexpr(heavy) {
                for(const auto& x : st[key].idxs) res[x] += val;
            } else {
                res += 1ll * cnt[key] * val;
            }
        }
        return res;
    }

    // Advances from state pos by character c. Returns {number of pattern
    // occurrences ending at this character, next state}. Requires build().
    pair<long long, int> move(const char c, int pos = 0) {
        assert(!cnt.empty() && "call build() first");
        pos = next(pos, slot(c));
        return {cnt[pos], pos};
    }

    // Advances from state pos through every character of s. Returns
    // {total occurrences ending inside s, final state}.
    pair<long long, int> move(const string& s, int pos = 0) {
        long long sum = 0;
        for(const char c : s) {
            auto nxt = move(c, pos);
            sum += nxt.first;
            pos = nxt.second;
        }
        return {sum, pos};
    }

    // Occurrences reported whenever state pos is entered (cnt[pos]).
    // Hides Trie::count, which counts inserted strings through a node.
    int count(const int pos) const {
        assert(0 <= pos and pos < (int)cnt.size());
        return cnt[pos];
    }

private:
    // True iff c maps to a character slot, i.e. margin <= c < margin + X.
    static bool in_alphabet(const char c) {
        return 0 <= c - margin and c - margin < (int)X;
    }

    // Transition slot for character c. Slot X is the failure link and must
    // never be reached from a character.
    static int slot(const char c) {
        assert(in_alphabet(c));
        return c - margin;
    }

    // st[dst].idxs := sorted union of st[dst].idxs and st[src].idxs
    // (both must already be sorted).
    void merge_ids(const int dst, const int src) {
        auto& own = st[dst].idxs;
        const auto& inherited = st[src].idxs;
        vector<int> merged;
        merged.reserve(own.size() + inherited.size());
        set_union(own.begin(), own.end(), inherited.begin(), inherited.end(), back_inserter(merged));
        own = std::move(merged);
    }
};
#line 2 "src/template/template.hpp"
#include <bits/stdc++.h>
using namespace std;
// Shorthand for long long (at least 64 bits by the standard).
using ll = long long;
// Pair of long longs.
using P = pair<long long, long long>;
// rep(i, a, b): loop with a long long i over a, a + 1, ..., b - 1.
// Note: the bound expression b is re-evaluated before every iteration.
#define rep(i, a, b) for(long long i = (a); i < (b); ++i)
// rrep(i, a, b): loop with a long long i over a, a - 1, ..., b (inclusive).
// Note: the bound expression b is re-evaluated before every iteration.
#define rrep(i, a, b) for(long long i = (a); i >= (b); --i)
// 4 * 10^18 as a long long (the double literal 4e18 converts exactly);
// presumably used as "infinity". 2 * inf = 8 * 10^18 still fits in long long.
constexpr long long inf = 4e18;
// Stream configuration hook: the global setup_io object below runs this
// constructor during static initialization (before main in practice).
struct SetupIO {
    SetupIO() {
        // Untie cin from cout so reading does not flush cout first.
        cin.tie(nullptr);
        // Stop synchronizing C++ streams with C stdio (faster iostreams;
        // mixing with printf/scanf afterwards is unsafe).
        ios::sync_with_stdio(false);
        // Floating-point output: fixed notation with 30 fractional digits.
        cout.setf(ios::fixed, ios::floatfield);
        cout.precision(30);
    }
} setup_io;
#line 3 "src/string/trie.hpp"
// Trie over the alphabet [margin, margin + X). Node 0 is the root; nodes are
// stored in st and referred to by index.
template <size_t X = 26, char margin = 'a'>
struct Trie {
    // One trie node.
    //   nxt[j] : child reached by character margin + j, or -1 if absent
    //   idxs   : every id inserted with a string ending at this node
    //   idx    : the most recently inserted such id (-1 if none)
    //   count  : number of inserted strings that have this node's string as
    //            a prefix (the root counts every insertion)
    //   parent : index of the parent node; key : character labelling the node
    struct Node {
        array<int, X> nxt;
        vector<int> idxs;
        int idx, count, parent;
        char key;
        Node(const char c, const int par)
            : idx(-1), count(0), parent(par), key(c) {
            nxt.fill(-1);
        }
    };
    vector<Node> st;
    // Starts with a single root node labelled c whose parent is p.
    Trie(const char c = '$', const int p = -1)
        : st(1, Node(c, p)) {}
    // Mutable reference to transition slot j of node i (assert-checked).
    inline int& next(const int i, const int j) {
        assert(i >= 0 && i < static_cast<int>(st.size()));
        assert(j >= 0 && j < static_cast<int>(X));
        return st[i].nxt[j];
    }
    // Inserts s under identifier x, creating missing nodes along the way.
    void insert(const string& s, const int x) {
        int cur = 0;
        for(const char ch : s) {
            ++st[cur].count;
            const int k = ch - margin;
            int child = next(cur, k);
            if(child == -1) {
                // Link first, then grow st: no reference into st is held
                // across the emplace_back that may reallocate it.
                child = static_cast<int>(st.size());
                next(cur, k) = child;
                st.emplace_back(ch, cur);
            }
            cur = child;
        }
        Node& last = st[cur];
        ++last.count;
        last.idx = x;
        last.idxs.emplace_back(x);
    }
    // Node index reached by spelling s from the root, or -1 if s is absent.
    int find(const string& s) {
        int cur = 0;
        for(const char ch : s) {
            const int child = next(cur, ch - margin);
            if(child < 0) return -1;
            cur = child;
        }
        return cur;
    }
    // Child of node pos via character c (-1 if there is none).
    int move(const int pos, const char c) {
        assert(pos >= 0 && pos < size());
        return next(pos, c - margin);
    }
    // Number of nodes, including the root.
    int size() const {
        return static_cast<int>(st.size());
    }
    // Most recent id stored at node pos (-1 if none).
    int idx(const int pos) const {
        assert(pos >= 0 && pos < size());
        return st[pos].idx;
    }
    // Number of inserted strings passing through (or ending at) node pos.
    int count(const int pos) const {
        assert(pos >= 0 && pos < size());
        return st[pos].count;
    }
    // Parent index of node pos (-1 for the root).
    int par(const int pos) const {
        assert(pos >= 0 && pos < size());
        return st[pos].parent;
    }
    // Copy of all ids stored at node pos.
    vector<int> idxs(const int pos) const {
        assert(pos >= 0 && pos < size());
        return st[pos].idxs;
    }
};
#line 4 "src/string/aho_corasick.hpp"
// Aho-Corasick automaton for matching many patterns over the alphabet
// [margin, margin + X).
// It reuses Trie<X + 1, margin>: slots [0, X) of each node are character
// transitions (completed into a full goto function by build()) and slot X
// holds the failure link.
// heavy == true : every state also keeps the sorted ids of all patterns that
//                 end there (directly or through failure links), so match()
//                 can report per-id counts.
// heavy == false: only the number of such patterns per state (cnt) is kept.
template <size_t X = 26, char margin = 'a', bool heavy = true>
struct AhoCorasick : Trie<X + 1, margin> {
    using TRIE = Trie<X + 1, margin>;
    using TRIE::next;
    using TRIE::st;
    using TRIE::TRIE;
    // cnt[v] = number of inserted (pattern, id) entries that are suffixes of
    // the string leading to state v. Filled by build(); empty before it.
    vector<int> cnt;

    // Adds pattern s under identifier x; must precede build().
    // Unlike the base Trie<X + 1>, asserts that every character is in
    // [margin, margin + X): margin + X would otherwise land in the
    // failure-link slot and be overwritten by build(), losing the pattern.
    void insert(const string& s, const int x) {
        assert(cnt.empty() && "insert() after build() is not supported");
        assert(all_of(s.begin(), s.end(), in_alphabet));
        TRIE::insert(s, x);
    }

    // Completes the goto function, sets failure links and accumulates cnt
    // (and, if heavy, the id lists) along failure links in BFS order.
    // Call exactly once, after all insert()s: a second call would treat the
    // root's redirected edges as children, enqueue the root and never stop.
    void build() {
        assert(cnt.empty() && "build() must be called exactly once");
        const int n = (int)st.size();
        cnt.resize(n);
        for(int i = 0; i < n; ++i) {
            // merge_ids relies on sorted id lists.
            if constexpr(heavy) sort(st[i].idxs.begin(), st[i].idxs.end());
            cnt[i] = (int)st[i].idxs.size();
        }
        queue<int> que;
        for(int i = 0; i < (int)X; ++i) {
            const int child = next(0, i);
            if(child < 0) {
                // Unknown first character: stay at the root.
                next(0, i) = 0;
                continue;
            }
            next(child, X) = 0;
            // cnt[child] will also receive cnt[0] (ids of an empty pattern);
            // keep the id list consistent with that.
            if constexpr(heavy) merge_ids(child, 0);
            que.emplace(child);
        }
        // BFS order: a failure target is strictly shallower than its state, so
        // cnt[fail] is final when v is processed, and the id list of
        // next(fail, i) was merged (when its parent was processed) before
        // child's list is merged here.
        while(!que.empty()) {
            const int v = que.front();
            que.pop();
            const int fail = next(v, X);
            cnt[v] += cnt[fail];
            for(int i = 0; i < (int)X; ++i) {
                const int child = next(v, i);
                if(child < 0) {
                    // Missing edge: reuse the failure state's transition.
                    next(v, i) = next(fail, i);
                    continue;
                }
                next(child, X) = next(fail, i);
                if constexpr(heavy) merge_ids(child, next(fail, i));
                que.emplace(child);
            }
        }
    }

    // Feeds text s starting from state pos and counts pattern occurrences
    // (one per end position per pattern).
    // heavy : returns id -> number of occurrences.
    // !heavy: returns the total number of occurrences.
    // Requires build(); every character of s must be in [margin, margin + X).
    conditional_t<heavy, unordered_map<int, long long>, long long> match(const string& s, int pos = 0) {
        assert(!cnt.empty() && "call build() first");
        // Visits per state, so each state's id list is scanned only once.
        unordered_map<int, int> pos_cnt;
        for(const auto& c : s) {
            pos = next(pos, slot(c));
            ++pos_cnt[pos];
        }
        conditional_t<heavy, unordered_map<int, long long>, long long> res{};
        for(const auto& [key, val] : pos_cnt) {
            if constexpr(heavy) {
                for(const auto& x : st[key].idxs) res[x] += val;
            } else {
                res += 1ll * cnt[key] * val;
            }
        }
        return res;
    }

    // Advances from state pos by character c. Returns {number of pattern
    // occurrences ending at this character, next state}. Requires build().
    pair<long long, int> move(const char c, int pos = 0) {
        assert(!cnt.empty() && "call build() first");
        pos = next(pos, slot(c));
        return {cnt[pos], pos};
    }

    // Advances from state pos through every character of s. Returns
    // {total occurrences ending inside s, final state}.
    pair<long long, int> move(const string& s, int pos = 0) {
        long long sum = 0;
        for(const char c : s) {
            auto nxt = move(c, pos);
            sum += nxt.first;
            pos = nxt.second;
        }
        return {sum, pos};
    }

    // Occurrences reported whenever state pos is entered (cnt[pos]).
    // Hides Trie::count, which counts inserted strings through a node.
    int count(const int pos) const {
        assert(0 <= pos and pos < (int)cnt.size());
        return cnt[pos];
    }

private:
    // True iff c maps to a character slot, i.e. margin <= c < margin + X.
    static bool in_alphabet(const char c) {
        return 0 <= c - margin and c - margin < (int)X;
    }

    // Transition slot for character c. Slot X is the failure link and must
    // never be reached from a character.
    static int slot(const char c) {
        assert(in_alphabet(c));
        return c - margin;
    }

    // st[dst].idxs := sorted union of st[dst].idxs and st[src].idxs
    // (both must already be sorted).
    void merge_ids(const int dst, const int src) {
        auto& own = st[dst].idxs;
        const auto& inherited = st[src].idxs;
        vector<int> merged;
        merged.reserve(own.size() + inherited.size());
        set_union(own.begin(), own.end(), inherited.begin(), inherited.end(), back_inserter(merged));
        own = std::move(merged);
    }
};
Back to top page