Given an array arr[] of size N, the task is to count the number of longest strictly increasing subsequences present in the given array.
Example:
Input: arr[] = {2, 2, 2, 2, 2} Output: 5 Explanation: Since the elements must strictly increase, the longest increasing subsequence has length 1, i.e. a single {2}. Each of the 5 elements forms one such subsequence, so the count of longest increasing subsequences is 5.
Input: arr[] = {1, 3, 5, 4, 7} Output: 2 Explanation: The length of the longest increasing subsequence is 4, and there are 2 longest increasing subsequences of length 4, namely {1, 3, 4, 7} and {1, 3, 5, 7}.
Approach: A dynamic-programming solution to this problem has already been discussed in a previous article. This article presents a different approach using a segment tree. Follow the steps below to solve the problem:
- Initialise the segment tree as an array of pairs, all initially (0, 0). In each node, the 1st element is the length of the LIS ending within the node's index range and the 2nd element is the number of LIS of that length.
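- The 1st element (the length) is computed as in the segment-tree solution for the LIS length. Process the elements in increasing order of value, breaking ties by decreasing index so that equal values cannot extend each other. For the element at index i, query the range [0, i - 1] to obtain (len, cnt): the best length, and its count, among the already-processed (hence smaller) elements to its left. Then update position i with (len + 1, max(1, cnt)); an element with no smaller element before it starts exactly one subsequence of length 1.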
- The 2nd element (the count) of each internal node is obtained by combining its two children as follows:
- In cases where the length of the left child > the length of the right child, the parent node becomes equal to the left child, since the longest subsequences end within the left child's range.
- In cases where the length of the left child < the length of the right child, the parent node becomes equal to the right child, since the longest subsequences end within the right child's range.
- In cases where the length of the left child = the length of the right child, the parent node keeps that length and its count becomes the sum of the counts of the left child and the right child.
- After all elements have been processed, the required answer is the 2nd element (the count) of the root of the segment tree.
Below is the implementation of the above approach:
C++
#include <bits/stdc++.h>
using namespace std;
// Array size the shared tree below is pre-allocated for.
const int M = 100000;
// Shared segment tree: node i holds {LIS length, number of LIS of that length}.
// Node 1 is the root; the children of node i are 2 * i and 2 * i + 1.
vector<pair<int, int>> tree(4 * M + 1, pair<int, int>(0, 0));
// Point update on the global tree. Node `idx` covers positions [start, end].
// The leaf for `update_idx` keeps the larger of its stored length and
// `length_t` and takes `count_c` as its count. Every ancestor on the way back
// up is then recomputed from its two children:
//   - the child with the greater length wins outright;
//   - on a tie, the parent keeps that length and sums both counts.
void update_tree(int start, int end,
                 int update_idx, int length_t,
                 int count_c, int idx)
{
    // Target position lies outside this node's range: nothing to change.
    if (update_idx < start || update_idx > end) {
        return;
    }

    if (start == end) {
        // Single-position node, necessarily the one for update_idx.
        pair<int, int>& leaf = tree[idx];
        leaf.first = max(leaf.first, length_t);
        leaf.second = count_c;
        return;
    }

    const int mid = (start + end) / 2;
    if (update_idx <= mid) {
        update_tree(start, mid, update_idx, length_t, count_c, 2 * idx);
    } else {
        update_tree(mid + 1, end, update_idx, length_t, count_c, 2 * idx + 1);
    }

    const pair<int, int> lhs = tree[2 * idx];
    const pair<int, int> rhs = tree[2 * idx + 1];
    if (lhs.first > rhs.first) {
        tree[idx] = lhs;
    } else if (rhs.first > lhs.first) {
        tree[idx] = rhs;
    } else {
        tree[idx] = make_pair(lhs.first, lhs.second + rhs.second);
    }
}
// Returns {largest LIS length, total count at that length} over positions
// [query_start, query_end], searching the subtree of node `idx`, which covers
// [start, end]. A subtree disjoint from the query contributes
// {INT32_MIN, 0}, so it can never beat real data.
pair<int, int> query(int start, int end,
                     int query_start,
                     int query_end, int idx)
{
    const bool fully_inside = query_start <= start && end <= query_end;
    if (fully_inside) {
        return tree[idx];
    }

    const bool disjoint = end < query_start || query_end < start;
    if (disjoint) {
        return pair<int, int>(INT32_MIN, 0);
    }

    const int mid = (start + end) / 2;
    const pair<int, int> lhs = query(start, mid, query_start, query_end, 2 * idx);
    const pair<int, int> rhs = query(mid + 1, end, query_start, query_end, 2 * idx + 1);

    // Equal lengths: subsequences from both halves count towards the total.
    if (lhs.first == rhs.first) {
        return make_pair(lhs.first, lhs.second + rhs.second);
    }
    return lhs.first > rhs.first ? lhs : rhs;
}
// Ordering for (value, index) pairs: ascending by value and, among equal
// values, descending by index. When elements are processed in this order, an
// element's equal-valued predecessors to its left have not been inserted yet,
// so equal values never extend one another (the subsequence stays strictly
// increasing).
bool comp(pair<int, int> a, pair<int, int> b)
{
    return a.first != b.first ? a.first < b.first
                              : a.second > b.second;
}
int countLIS( int arr[], int n)
{
vector<pair< int , int > > pair_array(n);
for ( int i = 0; i < n; i++) {
pair_array[i].first = arr[i];
pair_array[i].second = i;
}
sort(pair_array.begin(),
pair_array.end(), comp);
for ( int i = 0; i < n; i++) {
int update_idx = pair_array[i].second;
if (update_idx == 0) {
update_tree(0, n - 1, 0, 1, 1, 1);
continue ;
}
pair< int , int > temp
= query(0, n - 1, 0,
update_idx - 1, 1);
update_tree(0, n - 1, update_idx,
temp.first + 1,
max(1, temp.second), 1);
}
pair< int , int > ans
= query(0, n - 1, 0, n - 1, 1);
return ans.second;
}
// Demo: prints the number of longest increasing subsequences of a sample
// array ({1, 3, 4, 7} and {1, 3, 5, 7}, so the expected output is 2).
int main()
{
    int arr[] = { 1, 3, 5, 4, 7 };
    const int n = static_cast<int>(sizeof arr / sizeof arr[0]);
    cout << countLIS(arr, n);
    return 0;
}
|
Java
import java.util.*;
import java.io.*;
public class GFG{
    // Array size the tree was originally pre-allocated for. countLIS now sizes
    // the tree from its own input; the field is kept for compatibility.
    public static int M = 100000 ;

    // Segment tree of [LIS length, LIS count] pairs. Node 1 is the root and the
    // children of node i are 2 * i and 2 * i + 1. Rebuilt by every countLIS call.
    public static ArrayList<ArrayList<Integer>> tree =
        new ArrayList<ArrayList<Integer>>();

    // Creates a fresh, mutable [length, count] pair.
    private static ArrayList<Integer> pair(int length, int count)
    {
        return new ArrayList<Integer>(List.of(length, count));
    }

    // Stores length_t (keeping the larger of the old and new length) and
    // count_c in the leaf for update_idx within the subtree of node idx, which
    // covers [start, end], then recomputes that leaf's ancestors.
    public static void update_tree( int start, int end,
                                    int update_idx, int length_t,
                                    int count_c, int idx)
    {
        if (start == end && start == update_idx) {
            tree.get(idx).set(0, Math.max(tree.get(idx).get(0), length_t));
            tree.get(idx).set(1, count_c);
            return;
        }
        if (update_idx < start || end < update_idx) {
            return;
        }
        int mid = (start + end) / 2;
        update_tree(start, mid, update_idx, length_t, count_c, 2 * idx);
        update_tree(mid + 1, end, update_idx, length_t, count_c, 2 * idx + 1);

        // Unbox before comparing: == between two Integer objects compares
        // references and is only reliable for small cached values.
        int leftLen = tree.get(2 * idx).get(0);
        int leftCnt = tree.get(2 * idx).get(1);
        int rightLen = tree.get(2 * idx + 1).get(0);
        int rightCnt = tree.get(2 * idx + 1).get(1);
        if (leftLen == rightLen) {
            tree.set(idx, pair(leftLen, leftCnt + rightCnt));
        }
        else if (leftLen > rightLen) {
            tree.set(idx, pair(leftLen, leftCnt));
        }
        else {
            tree.set(idx, pair(rightLen, rightCnt));
        }
    }

    // Returns a new [length, count] pair: the largest LIS length within
    // [query_start, query_end] and the total number of LIS of that length.
    // Parts of the tree outside the query contribute [Integer.MIN_VALUE, 0].
    public static ArrayList<Integer> query( int start, int end,
                                            int query_start,
                                            int query_end, int idx)
    {
        if (query_start <= start && end <= query_end) {
            return new ArrayList<Integer>(tree.get(idx));
        }
        if (end < query_start || query_end < start) {
            return pair(Integer.MIN_VALUE, 0);
        }
        int mid = (start + end) / 2;
        ArrayList<Integer> left_child = query(start, mid,
                                              query_start,
                                              query_end, 2 * idx);
        ArrayList<Integer> right_child = query(mid + 1, end,
                                               query_start,
                                               query_end, 2 * idx + 1);
        // Both children are fresh copies, never tree nodes, so returning one
        // directly is safe.
        int leftLen = left_child.get(0);
        int rightLen = right_child.get(0);
        if (leftLen > rightLen) {
            return left_child;
        }
        if (rightLen > leftLen) {
            return right_child;
        }
        return pair(leftLen, left_child.get(1) + right_child.get(1));
    }

    // Counts the longest strictly increasing subsequences of arr[0..n-1].
    // Elements are inserted in increasing value order, with equal values by
    // decreasing index (see comp). When position i is inserted, every
    // already-inserted position before it therefore holds a strictly smaller
    // value. Returns 0 when n <= 0.
    public static int countLIS( int arr[], int n)
    {
        if (n <= 0) {
            return 0;
        }
        // Rebuild the tree for this call: sized for n (indices stay below
        // 4 * n) and filled with [0, 0]. The zero length, not
        // Integer.MIN_VALUE, is required: a prefix with nothing inserted yet
        // must report length 0, so the next element starts an LIS of length 1.
        tree.clear();
        for (int i = 0; i < 4 * n + 1; i++) {
            tree.add(pair(0, 0));
        }

        ArrayList<ArrayList<Integer>> pair_array = new ArrayList<ArrayList<Integer>>();
        for (int i = 0; i < n; i++) {
            pair_array.add(pair(arr[i], i));
        }
        Collections.sort(pair_array, new comp());

        for (ArrayList<Integer> entry : pair_array) {
            int update_idx = entry.get(1);
            if (update_idx == 0) {
                update_tree(0, n - 1, 0, 1, 1, 1);
                continue;
            }
            ArrayList<Integer> best = query(0, n - 1, 0, update_idx - 1, 1);
            update_tree(0, n - 1, update_idx, best.get(0) + 1,
                        Math.max(1, best.get(1)), 1);
        }
        return query(0, n - 1, 0, n - 1, 1).get(1);
    }

    public static void main(String args[])
    {
        int arr[] = { 1, 3, 5, 4, 7 };
        int n = arr.length;
        // Prints 2: {1, 3, 4, 7} and {1, 3, 5, 7}.
        System.out.println(countLIS(arr, n));
    }
}
public class comp implements Comparator<ArrayList<Integer>>{
public int compare(ArrayList<Integer> a, ArrayList<Integer> b)
{
if (a.get( 0 ).equals(b.get( 0 ))) {
return b.get( 1 ).compareTo(a.get( 1 ));
}
return a.get( 0 ).compareTo(b.get( 0 ));
}
}
|
C#
using System;
// Range-maximum segment tree over an int array. Nodes are 0-indexed: the root
// is node 0 and the children of node pos are 2 * pos + 1 and 2 * pos + 2.
class SegmentTree {
    private int[] tree;

    // Allocates room for a tree over `size` leaves: 2 * L - 1 nodes, where L
    // is the smallest power of two >= size. L is computed by integer doubling
    // rather than Math.Log, because Math.Log(0, 2) is -Infinity, which made the
    // old computation produce an invalid array length for size 0. Sizes below
    // 1 now get a single node.
    public SegmentTree(int size)
    {
        long leaves = 1;
        while (leaves < size) {
            leaves *= 2;
        }
        tree = new int[checked((int)(2 * leaves - 1))];
    }

    // Fills node pos, which covers arr[low..high], and its subtree with range
    // maxima. An empty range (low > high) stores nothing.
    public void BuildTree(int[] arr, int pos, int low, int high)
    {
        if (low > high) {
            return;
        }
        if (low == high) {
            tree[pos] = arr[low];
            return;
        }
        int mid = (low + high) / 2;
        BuildTree(arr, 2 * pos + 1, low, mid);
        BuildTree(arr, 2 * pos + 2, mid + 1, high);
        tree[pos] = Math.Max(tree[2 * pos + 1], tree[2 * pos + 2]);
    }

    // Maximum over positions [start, end] within node pos's range [low, high];
    // int.MinValue when the two ranges do not overlap.
    public int Query(int pos, int low, int high, int start, int end)
    {
        if (start <= low && end >= high) {
            return tree[pos];
        }
        if (start > high || end < low) {
            return int.MinValue;
        }
        int mid = (low + high) / 2;
        int left = Query(2 * pos + 1, low, mid, start, end);
        int right = Query(2 * pos + 2, mid + 1, high, start, end);
        return Math.Max(left, right);
    }
}
// Longest strictly increasing subsequence statistics, computed with a segment
// tree whose nodes store (LIS length, number of LIS of that length).
class LIS {
    // Segment tree over array positions 0..n-1. Node 1 is the root and the
    // children of node i are 2 * i and 2 * i + 1. Unwritten nodes hold (0, 0).
    private sealed class CountTree {
        private readonly int n;
        private readonly int[] length;
        private readonly long[] count;

        public CountTree(int n)
        {
            this.n = n;
            length = new int[4 * n + 1];
            count = new long[4 * n + 1];
        }

        // Stores (len, cnt) at position pos and recomputes its ancestors.
        public void Update(int pos, int len, long cnt)
        {
            Update(1, 0, n - 1, pos, len, cnt);
        }

        // Returns the best length over positions [from, to] (0 for an empty
        // range or one with nothing stored) and puts its count in cnt.
        public int Query(int from, int to, out long cnt)
        {
            return Query(1, 0, n - 1, from, to, out cnt);
        }

        private void Update(int node, int lo, int hi, int pos, int len, long cnt)
        {
            if (lo == hi) {
                length[node] = len;
                count[node] = cnt;
                return;
            }
            int mid = (lo + hi) / 2;
            if (pos <= mid) {
                Update(2 * node, lo, mid, pos, len, cnt);
            } else {
                Update(2 * node + 1, mid + 1, hi, pos, len, cnt);
            }
            int l = 2 * node;
            int r = 2 * node + 1;
            if (length[l] > length[r]) {
                length[node] = length[l];
                count[node] = count[l];
            } else if (length[r] > length[l]) {
                length[node] = length[r];
                count[node] = count[r];
            } else {
                // Equal lengths: subsequences from both halves qualify.
                length[node] = length[l];
                count[node] = count[l] + count[r];
            }
        }

        private int Query(int node, int lo, int hi, int from, int to, out long cnt)
        {
            if (to < lo || hi < from) {
                cnt = 0;
                return 0;
            }
            if (from <= lo && hi <= to) {
                cnt = count[node];
                return length[node];
            }
            int mid = (lo + hi) / 2;
            long leftCnt, rightCnt;
            int leftLen = Query(2 * node, lo, mid, from, to, out leftCnt);
            int rightLen = Query(2 * node + 1, mid + 1, hi, from, to, out rightCnt);
            if (leftLen > rightLen) {
                cnt = leftCnt;
                return leftLen;
            }
            if (rightLen > leftLen) {
                cnt = rightCnt;
                return rightLen;
            }
            cnt = leftCnt + rightCnt;
            return leftLen;
        }
    }

    // Returns the LIS length of arr and stores the number of such subsequences
    // in lisCount; both are 0 for an empty array.
    private static int Analyze(int[] arr, out long lisCount)
    {
        int n = arr.Length;
        if (n == 0) {
            lisCount = 0;
            return 0;
        }
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Visit positions by increasing value, and equal values by decreasing
        // index, so an element never extends an equal one.
        Array.Sort(order, (a, b) => arr[a] != arr[b] ? arr[a].CompareTo(arr[b])
                                                      : b.CompareTo(a));
        CountTree tree = new CountTree(n);
        foreach (int pos in order) {
            long prevCount;
            // Every already-inserted position left of pos holds a smaller value.
            int prevLength = tree.Query(0, pos - 1, out prevCount);
            // With no such position, pos starts exactly one subsequence of length 1.
            tree.Update(pos, prevLength + 1, Math.Max(1L, prevCount));
        }
        return tree.Query(0, n - 1, out lisCount);
    }

    // Length of the longest strictly increasing subsequence of arr (0 if empty).
    public static int GetLISLength(int[] arr)
    {
        long ignored;
        return Analyze(arr, out ignored);
    }

    // Number of longest strictly increasing subsequences of arr (0 if empty).
    public static long CountLIS(int[] arr)
    {
        long lisCount;
        Analyze(arr, out lisCount);
        return lisCount;
    }
}
class Program {
    static void Main(string[] args)
    {
        int[] arr = { 1, 3, 5, 4, 7 };
        // Prints 2: {1, 3, 4, 7} and {1, 3, 5, 7}.
        Console.WriteLine(LIS.CountLIS(arr));
    }
}
|
JavaScript
<script>
let M = 100000
let tree = new Array(4 * M + 1).fill(0).map(() => []);
// Point update on the shared tree. Node `idx` covers positions [start, end].
// The leaf for update_idx keeps the larger of its old length and length_t and
// takes count_c as its count. Each ancestor is then recomputed: the child with
// the greater length wins, and on a tie the parent keeps that length with the
// two counts summed.
function update_tree(start, end, update_idx, length_t, count_c, idx) {
    if (start == end && start == update_idx) {
        tree[idx] = [Math.max(tree[idx][0], length_t), count_c];
        return;
    }
    if (update_idx < start || end < update_idx) {
        return;
    }
    let mid = Math.floor((start + end) / 2);
    update_tree(start, mid, update_idx, length_t, count_c, 2 * idx);
    update_tree(mid + 1, end, update_idx, length_t, count_c, 2 * idx + 1);

    const left = tree[2 * idx];
    const right = tree[2 * idx + 1];
    // Always store a new array. Assigning a child's array to the parent would
    // make both nodes share storage, and an in-place write to one would then
    // silently change the other.
    if (left[0] == right[0]) {
        tree[idx] = [left[0], left[1] + right[1]];
    } else if (left[0] > right[0]) {
        tree[idx] = [left[0], left[1]];
    } else {
        tree[idx] = [right[0], right[1]];
    }
}
// Returns [largest length, total count at that length] over positions
// [query_start, query_end], searching the subtree of node `idx`, which covers
// [start, end]. A subtree disjoint from the query yields
// [Number.MIN_SAFE_INTEGER, 0], so it never wins a comparison.
function query(start, end, query_start, query_end, idx) {
    const covered = query_start <= start && end <= query_end;
    if (covered) {
        return tree[idx];
    }
    const outside = end < query_start || query_end < start;
    if (outside) {
        return [Number.MIN_SAFE_INTEGER, 0];
    }
    const mid = Math.floor((start + end) / 2);
    const lhs = query(start, mid, query_start, query_end, 2 * idx);
    const rhs = query(mid + 1, end, query_start, query_end, 2 * idx + 1);
    // Strictly longer side wins; otherwise merge the counts of both sides.
    return lhs[0] > rhs[0] ? lhs
         : rhs[0] > lhs[0] ? rhs
         : [lhs[0], lhs[1] + rhs[1]];
}
function comp(a, b) {
if (a[0] == b[0]) {
return a[1] > b[1];
}
return a[0] < b[0];
}
// Counts the longest strictly increasing subsequences of arr[0..n-1] using the
// shared [length, count] segment tree. Elements are inserted in increasing
// value order, with equal values by decreasing index (see comp). When position
// i is inserted, every already-inserted position in [0, i - 1] therefore holds
// a strictly smaller value. Returns 0 when n <= 0.
function countLIS(arr, n) {
    if (n <= 0) {
        return 0;
    }
    // Start every call from a fresh tree of distinct [0, 0] nodes sized for n
    // (a 1-indexed recursive segment tree uses indices below 4 * n). This
    // discards state left by earlier calls, guarantees numeric lengths
    // (Math.max on an undefined length would give NaN), and lifts the n <= M
    // limit.
    tree = Array.from({ length: 4 * n + 1 }, () => [0, 0]);

    const pair_array = [];
    for (let i = 0; i < n; i++) {
        pair_array.push([arr[i], i]);
    }
    pair_array.sort(comp);

    for (const [, update_idx] of pair_array) {
        if (update_idx == 0) {
            // Nothing can precede position 0: it starts one LIS of length 1.
            update_tree(0, n - 1, 0, 1, 1, 1);
            continue;
        }
        const best = query(0, n - 1, 0, update_idx - 1, 1);
        // A prefix with nothing inserted yet reports [0, 0]; the element then
        // starts exactly one subsequence of length 1.
        update_tree(0, n - 1, update_idx, best[0] + 1, Math.max(1, best[1]), 1);
    }
    return query(0, n - 1, 0, n - 1, 1)[1];
}
// Demo: write the number of longest increasing subsequences of a sample array.
const arr = [1, 3, 5, 4, 7];
const n = arr.length;
document.write(countLIS(arr, n));
</script>
|
Python3
import math
class SegmentTree:
    """Range-maximum segment tree over a list of numbers.

    Nodes are 0-indexed: the root is node 0 and the children of node ``pos``
    are ``2 * pos + 1`` and ``2 * pos + 2``.
    """

    def __init__(self, size):
        """Allocate storage for a tree over ``size`` leaves (at least one)."""
        # 2 * L - 1 nodes, where L is the smallest power of two >= size.
        # Integer bit arithmetic replaces math.log2, which raises ValueError
        # for size == 0.
        leaves = 1 << max(size - 1, 0).bit_length()
        self.tree = [0] * (2 * leaves - 1)

    def BuildTree(self, arr, pos, low, high):
        """Fill node ``pos`` (covering ``arr[low..high]``) and its subtree
        with range maxima. An empty range stores nothing."""
        if low > high:
            return
        if low == high:
            self.tree[pos] = arr[low]
            return
        mid = (low + high) // 2
        self.BuildTree(arr, 2 * pos + 1, low, mid)
        self.BuildTree(arr, 2 * pos + 2, mid + 1, high)
        self.tree[pos] = max(self.tree[2 * pos + 1], self.tree[2 * pos + 2])

    def Query(self, pos, low, high, start, end):
        """Return the maximum over positions ``start..end`` within node
        ``pos``'s range ``low..high``; ``float('-inf')`` when the ranges do
        not overlap."""
        if start <= low and end >= high:
            return self.tree[pos]
        if start > high or end < low:
            return float('-inf')
        mid = (low + high) // 2
        left = self.Query(2 * pos + 1, low, mid, start, end)
        right = self.Query(2 * pos + 2, mid + 1, high, start, end)
        return max(left, right)
class LIS:
    """Longest strictly increasing subsequence statistics computed with a
    segment tree whose nodes store (LIS length, number of LIS of that length).
    """

    @staticmethod
    def _length_and_count(arr):
        """Return ``(length, count)`` of the longest strictly increasing
        subsequences of ``arr``; ``(0, 0)`` for an empty sequence."""
        n = len(arr)
        if n == 0:
            return 0, 0
        # 1-indexed tree: node 1 is the root, children of node i are 2i and
        # 2i + 1; indices stay below 4 * n. Unwritten nodes hold (0, 0).
        lengths = [0] * (4 * n)
        counts = [0] * (4 * n)

        def update(node, lo, hi, pos, length, count):
            # Store (length, count) at pos, then recompute the ancestors.
            if lo == hi:
                lengths[node] = length
                counts[node] = count
                return
            mid = (lo + hi) // 2
            if pos <= mid:
                update(2 * node, lo, mid, pos, length, count)
            else:
                update(2 * node + 1, mid + 1, hi, pos, length, count)
            left, right = 2 * node, 2 * node + 1
            if lengths[left] == lengths[right]:
                # Equal lengths: subsequences from both halves qualify.
                lengths[node] = lengths[left]
                counts[node] = counts[left] + counts[right]
            elif lengths[left] > lengths[right]:
                lengths[node], counts[node] = lengths[left], counts[left]
            else:
                lengths[node], counts[node] = lengths[right], counts[right]

        def query(node, lo, hi, qlo, qhi):
            # Best (length, count) over positions qlo..qhi; (0, 0) if empty.
            if qhi < lo or hi < qlo:
                return 0, 0
            if qlo <= lo and hi <= qhi:
                return lengths[node], counts[node]
            mid = (lo + hi) // 2
            l_len, l_cnt = query(2 * node, lo, mid, qlo, qhi)
            r_len, r_cnt = query(2 * node + 1, mid + 1, hi, qlo, qhi)
            if l_len != r_len:
                return (l_len, l_cnt) if l_len > r_len else (r_len, r_cnt)
            return l_len, l_cnt + r_cnt

        # Increasing value; equal values by decreasing index so that equal
        # elements never extend each other (subsequences strictly increase).
        order = sorted(range(n), key=lambda i: (arr[i], -i))
        for pos in order:
            # Every already-inserted position left of pos holds a smaller value.
            prev_len, prev_cnt = query(1, 0, n - 1, 0, pos - 1)
            # With none, pos starts exactly one subsequence of length 1.
            update(1, 0, n - 1, pos, prev_len + 1, max(1, prev_cnt))
        return query(1, 0, n - 1, 0, n - 1)

    @staticmethod
    def GetLISLength(arr):
        """Return the length of the longest strictly increasing subsequence."""
        return LIS._length_and_count(arr)[0]

    @staticmethod
    def CountLIS(arr):
        """Return how many longest strictly increasing subsequences ``arr`` has."""
        return LIS._length_and_count(arr)[1]


arr = [1, 3, 5, 4, 7]
# Prints 2: {1, 3, 4, 7} and {1, 3, 5, 7}.
print(LIS.CountLIS(arr))
|
Time Complexity: O(N*log N) Auxiliary Space: O(N)
Related Topic: Segment Tree
|