MO’s Algorithm (Query Square Root Decomposition)

Hello Everyone,

Let us consider the following problem to understand MO’s Algorithm.
We are given an array and a set of query ranges, and we are required to find the sum of the elements in every query range.

Example:

Input: arr[] = {1, 1, 2, 1, 3, 4, 5, 2, 8};
query[] = [0, 4], [1, 3], [2, 4]
Output: Sum of arr[] elements in range [0, 4] is 8
Sum of arr[] elements in range [1, 3] is 4
Sum of arr[] elements in range [2, 4] is 6

A simple solution is to run a loop from L to R and calculate the sum of the elements in the given range for every query [L, R].

// Program to compute sum of ranges for different range
// queries.
#include <bits/stdc++.h>
using namespace std;

// Structure to represent a query range
struct Query
{
    int L, R;
};

// Prints sum of all query ranges. m is the number of queries
// and n is the size of the array.
void printQuerySums(int a[], int n, Query q[], int m)
{
    // One by one compute sum of all queries
    for (int i = 0; i < m; i++)
    {
        // Left and right boundaries of current range
        int L = q[i].L, R = q[i].R;

        // Compute sum of current query range
        int sum = 0;
        for (int j = L; j <= R; j++)
            sum += a[j];

        // Print sum of current query range
        cout << "Sum of [" << L << ", "
             << R << "] is " << sum << endl;
    }
}

// Driver program
int main()
{
    int a[] = {1, 1, 2, 1, 3, 4, 5, 2, 8};
    int n = sizeof(a) / sizeof(a[0]);
    Query q[] = {{0, 4}, {1, 3}, {2, 4}};
    int m = sizeof(q) / sizeof(q[0]);
    printQuerySums(a, n, q, m);
    return 0;
}

Output:

Sum of [0, 4] is 8
Sum of [1, 3] is 4
Sum of [2, 4] is 6

The time complexity of the above solution is O(mn).
The idea of MO’s algorithm is to pre-process all queries so that the result of one query can be reused in the next query. Below are the steps.

Let a[0…n-1] be the input array and q[0…m-1] be the array of queries.

  1. Sort all queries so that queries with L values from 0 to √n – 1 are put together, then all queries with L values from √n to 2*√n – 1, and so on. All queries within a block are sorted in increasing order of R values.
  2. Process all queries one by one so that every query uses the sum computed for the previous query.
  • Let ‘sum’ be the sum of the previous query.
  • Remove the extra elements of the previous query. For example, if the previous query is [0, 8] and the current query is [3, 9], then we subtract a[0], a[1] and a[2] from sum.
  • Add the new elements of the current query. In the same example, we add a[9] to sum.

The great thing about this algorithm is that, in step 2, the index variable for R changes at most O(n * √n) times throughout the run, and the index variable for L changes at most O(m * √n) times (see below, after the code, for details). These bounds are possible only because the queries are first sorted in blocks of size √n.

The preprocessing part (sorting the queries) takes O(m log m) time.

Processing all queries takes O(n * √n) + O(m * √n) = O((m+n) * √n) time.
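As a rough sense of scale, here is the operation count for an assumed input size of n = m = 10^5, ignoring constant factors:

```latex
\begin{aligned}
n = m &= 10^5, \qquad \sqrt{n} \approx 316 \\
\text{naive: } m \cdot n &= 10^{10} \\
\text{MO: } (m + n)\sqrt{n} &\approx 2 \cdot 10^5 \cdot 316 \approx 6.3 \cdot 10^7 \\
\text{sorting: } m \log_2 m &\approx 10^5 \cdot 17 = 1.7 \cdot 10^6
\end{aligned}
```

Under these assumptions the sorted order needs roughly 150 times fewer element updates than the naive loop, and the sort itself is a small fraction of the total.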

Below is the implementation of the above idea.

// Program to compute sum of ranges for different range
// queries using MO's algorithm
#include <bits/stdc++.h>
using namespace std;

// Variable to represent block size. This is made global
// so compare() of sort can use it.
int block;

// Structure to represent a query range
struct Query
{
    int L, R;
};

// Function used to sort all queries so that all queries
// of the same block are arranged together and within a block,
// queries are sorted in increasing order of R values.
bool compare(Query x, Query y)
{
    // Different blocks, sort by block.
    if (x.L / block != y.L / block)
        return x.L / block < y.L / block;

    // Same block, sort by R value
    return x.R < y.R;
}

// Prints sum of all query ranges. m is the number of queries
// and n is the size of array a[].
void queryResults(int a[], int n, Query q[], int m)
{
    // Find block size (at least 1, so compare() never divides by 0)
    block = max(1, (int)sqrt(n));

    // Sort all queries so that queries of the same block
    // are arranged together.
    sort(q, q + m, compare);

    // Initialize current L, current R and current sum. The current
    // window is the half-open range [currL, currR), initially empty.
    int currL = 0, currR = 0;
    int currSum = 0;

    // Traverse through all queries
    for (int i = 0; i < m; i++)
    {
        // L and R values of current range
        int L = q[i].L, R = q[i].R;

        // Remove extra elements of previous range. For
        // example if previous range is [0, 3] and current
        // range is [2, 5], then a[0] and a[1] are subtracted.
        // For sums this is safe even if currL passes currR,
        // because the additions below cancel these subtractions.
        while (currL < L)
        {
            currSum -= a[currL];
            currL++;
        }

        // Add elements of current range
        while (currL > L)
        {
            currSum += a[currL - 1];
            currL--;
        }
        while (currR <= R)
        {
            currSum += a[currR];
            currR++;
        }

        // Remove elements of previous range. For example
        // when previous range is [0, 10] and current range
        // is [3, 8], then a[9] and a[10] are subtracted
        while (currR > R + 1)
        {
            currSum -= a[currR - 1];
            currR--;
        }

        // Print sum of current range
        cout << "Sum of [" << L << ", " << R
             << "] is " << currSum << endl;
    }
}

// Driver program
int main()
{
    int a[] = {1, 1, 2, 1, 3, 4, 5, 2, 8};
    int n = sizeof(a) / sizeof(a[0]);
    Query q[] = {{0, 4}, {1, 3}, {2, 4}};
    int m = sizeof(q) / sizeof(q[0]);
    queryResults(a, n, q, m);
    return 0;
}

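Output:

Sum of [1, 3] is 4
Sum of [0, 4] is 8
Sum of [2, 4] is 6

The results are printed in the sorted query order, not the input order. Queries in the same block with equal R values, such as [0, 4] and [2, 4], may come out in either order because std::sort is not stable.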