OneKuma

OneKuma's Blog

One Lonely Kuma.
github
bilibili
twitter

2022 CCPC 綿陽現場賽 K Pattern Matching in A Minor Low Space 題解

題面#

求模板串 SSTT 中的出現次數。

其中 1S,T1071 \le |S|, |T| \le 10^7.

時間限制 55 秒,內存限制 1\mathbf{1} MB, 且輸入只能讀取一次.

題解#

KMP 模板題

題目讓你在一個空間受限的背景下,進行流式字符串匹配. 本題無論模板串還是詢問串都無法放進內存,你必須流式的一個一個讀取.

首先,回憶 Rabin-Karp 算法是怎麼樣的。它對模板串 SS 求了一個 hash 值,然後使用滑動窗口 維護詢問串所有長度 SS 的子串的 hash 值,這裡時間複雜度是 O(T)O(|T|) 的,但是空間複雜度是 O(S)O(|S|) 的 (需要存下整個窗口,以供刪除開頭的字符). 顯然無法滿足本題的要求.

官方題解的做法:

  1. 讀入模板串 SS, 計算出長度為 n\lfloor \sqrt{n} \rfloor 的前綴 hash 值和完整串的 hash 值,時間複雜度 O(S)O(|S|), 空間複雜度 O(1)O(1);
  2. 讀入詢問串 TT, 維護 n\lfloor \sqrt{n} \rfloor 大小的滑動窗口,可以求出詢問串中所有和這個根號前綴匹配的位置。相當於,篩出去了一些必不可能匹配上的位置,留下來一些位置需要求出相應的全長度的 hash 值.
    • 注意到,直到這裡實際上還沒有本質的改善,一是匹配上的位置仍然很多,二是難以搞出相應的全串 hash 值.
    • 根據 (Weak) Periodicity Lemm, 我們可以把匹配上的位置分成至多 n\lceil \sqrt{n} \rceil等差數列. 具體的,每組中所有匹配的子串是一個順次有重疊的列表,相鄰出現位置的距離差相等 (也就是所謂的構成等差數列,進一步,重疊部分其實就是根號前綴的 border). 於是,記錄匹配位置的空間被壓縮到了 O(n)O(\lfloor \sqrt{n} \rfloor).
    • 下一個問題是,我們不僅需要記錄匹配的位置,還需要記錄匹配處的前綴 hash 值,等到處理到該次可能匹配的結束位置時來進行全串的比對。同樣根據上述引理的結論,同一組等差數列內部,相鄰 2 個匹配位置之間的 hash 值是固定的,我們只需要記錄等差數列的起點,公差,末點,起點的 hash 值,公差對應字符串部分的 hash 值, 就能表示出這一個等差數列的所有位置信息和 hash 值信息.

最終,時間複雜度 O(n)O(n), 空間複雜度 O(n)O(\lfloor \sqrt{n} \rfloor), 一是維護滑動窗口,二是維護所有等差數列.

週期相關理論的參考文獻:

  • 金策,《字符串算法選講》
  • 2019 年集訓隊論文,陳孫立,《子串周期查詢問題的相關算法及其應用》

官方題解#

official tutorial

分塊 KMP#

一個不要 period 的做法,大概是把模式串分成 sqrt 個字符一塊,每塊合併成一個大字符,然後關於大字符跑 sqrt 個並排的 kmp

代碼#

#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

using namespace std;

const int base = 131;
const int SZ = 3162;
const int cap = 4096;

namespace {
  // biv = base^{-1}, bsziv = base^{-SZ+1}
  int mod, biv, bsizv;

  random_device rd;
  mt19937 rnd(rd());
  uniform_int_distribution<> gen(100000000, 900000000);

  bool isPrime(int n) {
    if (n % 2 == 0) return false;
    int sq = (int) sqrt(n) + 1;
    for (int i = 3; i <= sq; i += 2) {
      if (n % i == 0) return false;
    }
    return true;
  }

  int add(int a, int b) {
    a += b;
    if (a >= mod) a -= mod;
    return a;
  }

  int sub(int a, int b) {
    a -= b;
    if (a < 0) a += mod;
    return a;
  }

  int mul(int a, int b) {
    return 1ll * a * b % mod;
  }

  int qpow(int x, int n) {
    int r = 1;
    while (n) {
      if (n & 1) r = mul(r, x);
      n >>= 1;
      x = mul(x, x);
    }
    return r;
  }

  void init() {
    while (true) {
      mod = gen(rnd);
      if (isPrime(mod)) {
        break;
      }
    }
    biv = qpow(base, mod - 2);
    bsizv = qpow(qpow(base, SZ - 1), mod - 2);
  }
}

int n, m, preh = 0, allh = 0;

struct Ring {
  char buf[cap];
  int head = 0, tail = 0, size = 0, len;
  int hsh = 0, xp = 1;

  void init(int n) {
    len = n;
    hsh = 0;
    head = tail = size = 0;
    xp = 1;
  }

  char pop() {
    size--;
    char x = buf[(head++) % cap];
    return x;
  }

  void push(char x) {
    size++;
    buf[(tail++) % cap] = x;
  }

  void append(char c) {
    push(c);
    hsh = add(mul(c, xp), hsh);
    if (size > len) {
      hsh = sub(hsh, pop());
      hsh = mul(hsh, biv);
    } else {
      xp = mul(xp, base);
    }
  }
} f;

struct Per {
  int start, delta = -1, end = -1;
  int xp, hsh, dhsh = -1;

  Per(int p, int x, int h) : start(p + 1 - SZ) {
    end = start;
    xp = mul(x, bsizv);
    hsh = sub(h, mul(preh, xp));
  }

  bool next(int p, int curx, int curh) {
    p = p + 1 - SZ;
    if (delta == -1) {
      if (p - start >= SZ) {
        // it should have overlap part
        return false;
      } else {
        // set delta
        end = p;
        delta = end - start;
        curh = sub(curh, mul(preh, mul(curx, bsizv)));
        dhsh = sub(curh, hsh);
        return true;
      }
    } else {
      if (p - end == delta) {
        end = p;
        return true;
      } else {
        return false;
      }
    }
  }

  bool match(int pos, int curv) {
    if (start + n - 1 == pos) {
      int target = sub(curv, hsh);
      bool ok = target == mul(allh, xp);
      if (delta != -1) {
        start += delta;
        int dxp = qpow(base, delta);
        xp = mul(xp, dxp);
        hsh = add(hsh, dhsh);
        dhsh = mul(dhsh, dxp);
      }
      return ok;
    }
    return false;
  }
};

int main() {
  init();

  scanf("%d%d", &n, &m);
  getchar(); // end of line

  for (int i = 1, xp = 1; i <= n; i++, xp = mul(xp, base)) {
    char c = getchar();
    int val = mul(c, xp);
    if (i <= SZ) {
      preh = add(preh, val);
    }
    allh = add(allh, val);
  }
  getchar(); // end of line

  f.init(n <= SZ ? n : SZ);
  int ans = 0, curv = 0, matched = 0;
  vector<Per> ps;
  for (int i = 1, xp = 1; i <= m; i++, xp = mul(xp, base)) {
    char c = getchar();
    f.append(c);
    if (n <= SZ) {
      ans += f.size == n && f.hsh == allh;
    } else {
      curv = add(curv, mul(c, xp));
      if (f.size == SZ && f.hsh == preh) {
        // match the sqrt prefix
        // extend the last group or create a new group
        if (ps.empty() || !ps.back().next(i, xp, curv)) {
          ps.emplace_back(i, xp, curv);
        }
      }
      if (matched < ps.size()) {
        ans += ps[matched].match(i, curv);
        if (ps[matched].end + n - 1 == i) {
          matched++;
        }
      }
    }
  }

  printf("%d\n", ans);

  return 0;
}
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。