Problem Description#
Find the number of occurrences of the template string $S$ (of length $n$) in the query string $T$ (of length $m$).
Where .
Time limit is seconds, memory limit is MB, and input can only be read once.
Solution#
KMP template problem
The problem asks you to perform streaming string matching in a space-limited environment. In this problem, both the template string and the query string cannot fit into memory, so you must read them one by one in a streaming manner.
First, recall how the Rabin-Karp algorithm works. It calculates a hash value for the template string and uses a sliding window to maintain the hash values of all substrings of length $n$ in the query string. The time complexity is $O(n + m)$, but the space complexity is $O(n)$ (the entire window must be stored so that its first character can be removed when the window slides). Obviously, this cannot meet the requirements of this problem.
The approach in the official solution:
- Read the template string and calculate the hash value of its prefix of length $\sqrt{n}$ and the hash value of the complete string, with a time complexity of $O(n)$ and a space complexity of $O(1)$.
- Read the query string and maintain a sliding window of size $\sqrt{n}$, which finds all positions in the query string that match this square-root prefix. In other words, positions that cannot possibly match are filtered out, leaving only candidate positions for which the full-length hash value still has to be checked.
- Note that until this point, there has been no essential improvement. First, there are still many matching positions, and second, it is difficult to obtain the hash value of the full string.
- According to the (Weak) Periodicity Lemma, we can divide the pending matching positions into at most $O(\sqrt{n})$ groups, each of which is an arithmetic sequence. Specifically, the matching substrings in each group form a chain in which consecutive matches overlap and the distance between adjacent positions is constant (this is the so-called arithmetic sequence; moreover, the overlapping part is a border of the square-root prefix). As a result, the space for recording the matching positions is compressed to $O(\sqrt{n})$.
- The next problem is that we not only need to record the matching positions, but also need to record the prefix hash values at the matching positions, so that we can compare the full string when processing the corresponding ending position. Similarly, according to the conclusion of the above lemma, the hash values between adjacent two matching positions within the same group are fixed. We only need to record the starting point, common difference, ending point, hash value of the starting point, hash value of the corresponding string part of the common difference, to represent all the position information and hash value information of this arithmetic sequence.
Finally, the time complexity is $O(n + m)$, and the space complexity is $O(\sqrt{n} + n / \sqrt{n}) = O(\sqrt{n})$. The first term is for maintaining the sliding window, and the second is for maintaining all arithmetic sequences.
References for periodicity-related theories:
- Jin Ce, "Selected Topics in String Algorithms"
- 2019 Training Team Paper, Chen Sunli, "Related Algorithms and Applications of Substring Period Query Problem"
Official Solution#
Block KMP#
An approach that does not rely on periods is to divide the pattern string into blocks of $\sqrt{n}$ characters, treat each block as a single "large character", and then run KMP over the resulting string of large characters.
Code#
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>
using namespace std;
// Radix of the polynomial rolling hash.
constexpr int base = 131;
// Length of the short pattern prefix that is tracked with a sliding window
// (the "square-root prefix" of the write-up; presumably ~sqrt of the largest
// pattern length - confirm against the problem limits).
constexpr int SZ = 3162;
// Capacity of the Ring character buffer; it has to hold SZ + 1 characters
// because a character is pushed before the oldest one is popped.
constexpr int cap = 4096;
namespace {
// mod   - modulus of every hash computation, a random prime picked by init()
// biv   = base^{-1}       (mod `mod`)
// bsizv = base^{-(SZ-1)}  (mod `mod`)
int mod, biv, bsizv;
std::random_device rd;
std::mt19937 rnd(rd());
// Candidate moduli are drawn uniformly from [1e8, 9e8]; the sum of two residues
// below 9e8 still fits in an int (assuming the usual 32-bit int).
std::uniform_int_distribution<> gen(100000000, 900000000);

// Trial-division primality test.
// Returns true iff n is prime; every n < 2 (including negatives) is rejected
// and 2 is accepted.
bool isPrime(int n) {
    if (n < 2) return false;
    if (n % 2 == 0) return n == 2;
    // 64-bit divisor so that d * d cannot overflow for n close to INT_MAX.
    for (long long d = 3; d * d <= n; d += 2) {
        if (n % d == 0) return false;
    }
    return true;
}

// (a + b) mod `mod`; both operands must already lie in [0, mod).
int add(int a, int b) {
    a += b;
    if (a >= mod) a -= mod;
    return a;
}

// (a - b) mod `mod`; both operands must already lie in [0, mod).
int sub(int a, int b) {
    a -= b;
    if (a < 0) a += mod;
    return a;
}

// (a * b) mod `mod`, always returned in [0, mod).
// Operands may be negative (callers pass raw `char` values, which are negative
// for bytes >= 0x80 when char is signed); `%` truncates toward zero, so the
// remainder is shifted back into the canonical range.
int mul(int a, int b) {
    long long r = 1ll * a * b % mod;
    if (r < 0) r += mod;
    return static_cast<int>(r);
}

// x^n mod `mod` by binary exponentiation; requires n >= 0.
int qpow(int x, int n) {
    int r = 1;
    while (n) {
        if (n & 1) r = mul(r, x);
        n >>= 1;
        x = mul(x, x);
    }
    return r;
}

// Picks a random prime modulus and precomputes the modular inverses
// (x^{mod-2} is the inverse of x modulo a prime, by Fermat's little theorem).
void init() {
    do {
        mod = gen(rnd);
    } while (!isPrime(mod));
    biv = qpow(base, mod - 2);
    bsizv = qpow(qpow(base, SZ - 1), mod - 2);
}
}
// n = pattern (template) length, m = text (query) length, both read in main().
int n = 0;
int m = 0;
// Polynomial hashes of the pattern accumulated in main():
// preh covers its first min(n, SZ) characters, allh covers all n characters.
int preh = 0;
int allh = 0;
// Fixed-capacity circular buffer holding the most recent `len` characters
// together with the polynomial hash of that window:
//   hsh = sum over window positions k (0-based) of ch[k] * base^k  (mod `mod`).
struct Ring {
    char buf[cap];
    int head = 0;  // running read counter; slot is head % cap
    int tail = 0;  // running write counter; slot is tail % cap
    int size = 0;  // characters currently stored
    int len;       // window length, set by init()
    int hsh = 0;   // hash of the stored characters
    int xp = 1;    // base^size while filling, stays at base^len afterwards

    // Empties the buffer and sets the window length to n.
    void init(int n) {
        head = 0;
        tail = 0;
        size = 0;
        hsh = 0;
        xp = 1;
        len = n;
    }

    // Removes and returns the oldest character.
    char pop() {
        const int slot = head % cap;
        ++head;
        --size;
        return buf[slot];
    }

    // Stores x as the newest character.
    void push(char x) {
        const int slot = tail % cap;
        ++tail;
        ++size;
        buf[slot] = x;
    }

    // Slides the window by one character and keeps hsh in sync.
    void append(char c) {
        push(c);
        // The new character enters with weight xp.
        hsh = add(mul(c, xp), hsh);
        if (size <= len) {
            // Still filling (or just filled) the window: only the weight grows.
            xp = mul(xp, base);
            return;
        }
        // Window overflowed: drop the oldest character (weight base^0) and
        // divide by base so the remaining weights start at base^0 again.
        hsh = mul(sub(hsh, pop()), biv);
    }
} f;
// One group of occurrences of the SZ-character pattern prefix whose start
// positions form an arithmetic sequence start, start + delta, ..., end.
// Alongside the positions it carries what is needed to test the full pattern
// at each start without storing the text.
struct Per {
    int start;       // start position of the next occurrence still to be verified
    int delta = -1;  // common difference; -1 while the group has a single member
    int end = -1;    // start position of the last occurrence in the group
    int xp;          // base^(start - 1)
    int hsh;         // hash of the text prefix that ends just before `start`
    int dhsh = -1;   // hash contribution of the text in [start, start + delta)

    // p: text index at which the prefix window just ended,
    // x: base^(p - 1), h: text prefix hash through p.
    Per(int p, int x, int h)
        : start(p + 1 - SZ),
          end(start),
          xp(mul(x, bsizv)),
          hsh(sub(h, mul(preh, xp))) {}

    // Tries to add the prefix occurrence ending at text index p to this group.
    // curx / curh are base^(p - 1) and the text prefix hash through p.
    // Returns false when the occurrence does not continue the sequence.
    bool next(int p, int curx, int curh) {
        const int from = p + 1 - SZ;
        if (delta != -1) {
            // Common difference already fixed: accept only the exact next term.
            if (from - end != delta) return false;
            end = from;
            return true;
        }
        // Second member of the group: it should overlap the first occurrence.
        if (from - start >= SZ) return false;
        end = from;
        delta = end - start;
        const int before = sub(curh, mul(preh, mul(curx, bsizv)));
        dhsh = sub(before, hsh);
        return true;
    }

    // Call with the current text index pos and text prefix hash curv.
    // When pos is where a full match starting at `start` would end, compares
    // the hashes, advances to the next member of the group and returns the
    // comparison result; otherwise returns false and changes nothing.
    bool match(int pos, int curv) {
        if (start + n - 1 != pos) return false;
        const bool ok = sub(curv, hsh) == mul(allh, xp);
        if (delta != -1) {
            const int dxp = qpow(base, delta);
            start += delta;
            xp = mul(xp, dxp);
            hsh = add(hsh, dhsh);
            // The text repeats with period delta inside the group, so the next
            // chunk contributes the same value shifted by base^delta.
            dhsh = mul(dhsh, dxp);
        }
        return ok;
    }
};
int main() {
init();
scanf("%d%d", &n, &m);
getchar(); // end of line
for (int i = 1, xp = 1; i <= n; i++, xp = mul(xp, base)) {
char c = getchar();
int val = mul(c, xp);
if (i <= SZ) {
preh = add(preh, val);
}
allh = add(allh, val);
}
getchar(); // end of line
f.init(n <= SZ ? n : SZ);
int ans = 0, curv = 0, matched = 0;
vector<Per> ps;
for (int i = 1, xp = 1; i <= m; i++, xp = mul(xp, base)) {
char c = getchar();
f.append(c);
if (n <= SZ) {
ans += f.size == n && f.hsh == allh;
} else {
curv = add(curv, mul(c, xp));
if (f.size == SZ && f.hsh == preh) {
// match the sqrt prefix
// extend the last group or create a new group
if (ps.empty() || !ps.back().next(i, xp, curv)) {
ps.emplace_back(i, xp, curv);
}
}
if (matched < ps.size()) {
ans += ps[matched].match(i, curv);
if (ps[matched].end + n - 1 == i) {
matched++;
}
}
}
}
printf("%d\n", ans);
return 0;
}