OneKuma

OneKuma's Blog

One Lonely Kuma.
github
bilibili
twitter

2022 CCPC 绵阳現場大会 K パターンマッチング in A Minor Low Space 解説

問題#

テンプレート文字列 SSTT に出現する回数を求める。

ここで 1S,T1071 \le |S|, |T| \le 10^7

時間制限 55 秒、メモリ制限 1\mathbf{1} MB、かつ入力は一度だけ読み取ることができる

解法#

KMP テンプレート問題

問題は、メモリが制限された状況でストリーム文字列マッチングを行うことを求めています。この問題では、テンプレート文字列もクエリ文字列もメモリに収めることができず、一つずつストリームで読み取る必要があります。

まず、Rabin-Karp アルゴリズムがどのようなものであるかを思い出してください。これはテンプレート文字列 SS のハッシュ値を求め、その後スライディングウィンドウを使用してクエリ文字列のすべての長さ SS の部分文字列のハッシュ値を維持します。ここでの時間計算量は O(T)O(|T|) ですが、空間計算量は O(S)O(|S|) です(ウィンドウ全体を保持する必要があるため、先頭の文字を削除するために)。明らかにこの問題の要件を満たすことはできません。

公式解法の手法:

  1. テンプレート文字列 SS を読み込み、長さ n\lfloor \sqrt{n} \rfloor の前方ハッシュ値と完全な文字列のハッシュ値を計算します。時間計算量 O(S)O(|S|)、空間計算量 O(1)O(1)
  2. クエリ文字列 TT を読み込み、n\lfloor \sqrt{n} \rfloor サイズのスライディングウィンドウを維持し、この平方根前方に一致するすべての位置を求めることができます。つまり、一部の必ず一致しない位置をフィルタリングし、全長のハッシュ値を求める必要がある位置を残します。
    • 注意すべきは、ここまで実際には本質的な改善がないことです。一つは一致した位置がまだ多く、もう一つは対応する全文字列のハッシュ値を得るのが難しいことです。
    • (Weak) Periodicity Lemm に基づいて、一致した位置を最大 n\lceil \sqrt{n} \rceil グループの等差数列に分けることができます。具体的には、各グループ内のすべての一致した部分文字列は順次重なり合ったリストであり、隣接する出現位置の距離差は等しい(いわゆる等差数列を構成し、さらに重なり部分は実際には平方根前方のボーダーです)。したがって、一致した位置を記録するための空間は O(n)O(\lfloor \sqrt{n} \rfloor) に圧縮されました。
    • 次の問題は、一致した位置を記録するだけでなく、一致した位置の前方ハッシュ値も記録する必要があることです。次に、その可能性のある一致の終了位置を処理する際に全文字列の比較を行います。同様に、上記の引理の結論に基づいて、同じグループの等差数列内では、隣接する 2 つの一致位置間のハッシュ値は固定されており、等差数列の起点、公差、終点、起点のハッシュ値、公差に対応する文字列部分のハッシュ値を記録することで、この等差数列のすべての位置情報とハッシュ値情報を表現できます。

最終的に、時間計算量は O(n)O(n)、空間計算量は O(n)O(\lfloor \sqrt{n} \rfloor) であり、スライディングウィンドウを維持することと、すべての等差数列を維持することが含まれます。

周期に関連する理論の参考文献:

  • 金策,《文字列アルゴリズム選講》
  • 2019 年集訓隊論文,陳孫立,《部分文字列周期クエリ問題の関連アルゴリズム及びその応用》

公式解法#

official tutorial

分割 KMP#

一つの周期を持たない方法は、パターン文字列を平方根個の文字に分割し、各ブロックを一つの大きな文字に統合し、その大きな文字について平方根個の並列 KMP を実行することです。

コード#

#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

using namespace std;

const int base = 131;
const int SZ = 3162;
const int cap = 4096;

namespace {
  // biv = base^{-1}, bsziv = base^{-SZ+1}
  int mod, biv, bsizv;

  random_device rd;
  mt19937 rnd(rd());
  uniform_int_distribution<> gen(100000000, 900000000);

  bool isPrime(int n) {
    if (n % 2 == 0) return false;
    int sq = (int) sqrt(n) + 1;
    for (int i = 3; i <= sq; i += 2) {
      if (n % i == 0) return false;
    }
    return true;
  }

  int add(int a, int b) {
    a += b;
    if (a >= mod) a -= mod;
    return a;
  }

  int sub(int a, int b) {
    a -= b;
    if (a < 0) a += mod;
    return a;
  }

  int mul(int a, int b) {
    return 1ll * a * b % mod;
  }

  int qpow(int x, int n) {
    int r = 1;
    while (n) {
      if (n & 1) r = mul(r, x);
      n >>= 1;
      x = mul(x, x);
    }
    return r;
  }

  void init() {
    while (true) {
      mod = gen(rnd);
      if (isPrime(mod)) {
        break;
      }
    }
    biv = qpow(base, mod - 2);
    bsizv = qpow(qpow(base, SZ - 1), mod - 2);
  }
}

int n, m, preh = 0, allh = 0;

struct Ring {
  char buf[cap];
  int head = 0, tail = 0, size = 0, len;
  int hsh = 0, xp = 1;

  void init(int n) {
    len = n;
    hsh = 0;
    head = tail = size = 0;
    xp = 1;
  }

  char pop() {
    size--;
    char x = buf[(head++) % cap];
    return x;
  }

  void push(char x) {
    size++;
    buf[(tail++) % cap] = x;
  }

  void append(char c) {
    push(c);
    hsh = add(mul(c, xp), hsh);
    if (size > len) {
      hsh = sub(hsh, pop());
      hsh = mul(hsh, biv);
    } else {
      xp = mul(xp, base);
    }
  }
} f;

struct Per {
  int start, delta = -1, end = -1;
  int xp, hsh, dhsh = -1;

  Per(int p, int x, int h) : start(p + 1 - SZ) {
    end = start;
    xp = mul(x, bsizv);
    hsh = sub(h, mul(preh, xp));
  }

  bool next(int p, int curx, int curh) {
    p = p + 1 - SZ;
    if (delta == -1) {
      if (p - start >= SZ) {
        // 重なり部分が必要です
        return false;
      } else {
        // delta を設定
        end = p;
        delta = end - start;
        curh = sub(curh, mul(preh, mul(curx, bsizv)));
        dhsh = sub(curh, hsh);
        return true;
      }
    } else {
      if (p - end == delta) {
        end = p;
        return true;
      } else {
        return false;
      }
    }
  }

  bool match(int pos, int curv) {
    if (start + n - 1 == pos) {
      int target = sub(curv, hsh);
      bool ok = target == mul(allh, xp);
      if (delta != -1) {
        start += delta;
        int dxp = qpow(base, delta);
        xp = mul(xp, dxp);
        hsh = add(hsh, dhsh);
        dhsh = mul(dhsh, dxp);
      }
      return ok;
    }
    return false;
  }
};

int main() {
  init();

  scanf("%d%d", &n, &m);
  getchar(); // 行の終わり

  for (int i = 1, xp = 1; i <= n; i++, xp = mul(xp, base)) {
    char c = getchar();
    int val = mul(c, xp);
    if (i <= SZ) {
      preh = add(preh, val);
    }
    allh = add(allh, val);
  }
  getchar(); // 行の終わり

  f.init(n <= SZ ? n : SZ);
  int ans = 0, curv = 0, matched = 0;
  vector<Per> ps;
  for (int i = 1, xp = 1; i <= m; i++, xp = mul(xp, base)) {
    char c = getchar();
    f.append(c);
    if (n <= SZ) {
      ans += f.size == n && f.hsh == allh;
    } else {
      curv = add(curv, mul(c, xp));
      if (f.size == SZ && f.hsh == preh) {
        // 平方根前方に一致
        // 最後のグループを拡張するか、新しいグループを作成
        if (ps.empty() || !ps.back().next(i, xp, curv)) {
          ps.emplace_back(i, xp, curv);
        }
      }
      if (matched < ps.size()) {
        ans += ps[matched].match(i, curv);
        if (ps[matched].end + n - 1 == i) {
          matched++;
        }
      }
    }
  }

  printf("%d\n", ans);

  return 0;
}
読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。