4

动态开点线段树说明 - Grey Zeng

 1 year ago
source link: https://www.cnblogs.com/greyzeng/p/17011344.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

动态开点线段树说明

作者:Grey

原文地址:

博客园:动态开点线段树说明

CSDN:动态开点线段树说明

说明#

针对普通线段树,参考使用线段树解决数组任意区间元素修改问题

在普通线段树中,线段树在预处理的时候,需要申请 4 倍大小的数组空间来存放划分的区域,

而本文介绍的动态开点线段树,它和普通线段树的区别是,动态开点线段树不需要像普通线段树那样提前申请 4 倍大小的数据空间来存放划分区域,等到实际使用的时候,再来申请。

先讲一种比较简单的动态开点线段树,这种线段树只支持单点的更新和查询。

即支持如下两个方法

void add(i, v);

该方法表示在 i 上的值加上 v;

int query(int s, int e)

该方法用于获取 s 到 e 区间内的累加和信息。

该线段树只需要定义一个节点数据结构即可

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

其中 sum 表示 Node 所在区间的累加和,left 表示节点左孩子信息,right 表示节点右孩子信息。

线段树初始化过程也只需要

  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }
  }

size 表示线段树支持的范围,这个范围从线段树一开始初始化的时候设定好(编号1 到 编号size就是区间范围)。和普通线段树不一样的地方在于,节点只建立了 root 节点,未初始化所有区间。

接下来看add方法,

    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

这个方法调用了线段树内部的私有add方法,

    // c-> cur 当前节点!表达的范围 l~r
    // i位置的数,增加v
    // 潜台词!i一定在l~r范围上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 还可以划分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

这个add方法的几个参数分别代表

c : 表示 add 操作的区间代表节点是多少

l...r 表示任务区间,由于初始化 size,所以在调用公开的 add 方法时候,l = 1, r = size,表示在初始化区间范围内操作。

i:表示要操作的位置

v: 表示要增加的值

整个 add 私有方法逻辑也比较简单,核心代码

        // i 在节点左边
        if (i <= mid) { 
            // 如果节点的左树为空,则建立新节点
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { 
            // i 在节点右边
            // 如果节点右树为空,则建立新节点
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        // 最后当前节点要汇聚左右树的结果,之所以要判空是因为左右树可能不需要都建立出来
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);

查询方法的逻辑也比较简单

    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

调用了内部的一个私有 query 方法,

    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) { 
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

这个私有方法的几个参数说明如下

c:表示要操作的线段树的代表节点是什么;

l...r 是划分的区间范围

s...e 是任务的区间范围

核心逻辑如下

// 如果任务的区间已经包含了划分的区间,直接返回结果
      if (s <= l && r <= e) { 
        return c.sum;
      }
      // 否则,去左右区间拿累加和
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        // 整合成自己的累加和返回
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }

整个支持单点更新的动态线段树的完整代码如下(含对数器代码)

// 只支持单点增加 + 范围查询的动态开点线段树(累加和)
public class Code01_DynamicSegmentTree {

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

  // arr[0] -> 1
  // 线段树,从1开始下标!
  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }

    // 下标i这个位置的数,增加v
    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

    // c-> cur 当前节点!表达的范围 l~r
    // i位置的数,增加v
    // 潜台词!i一定在l~r范围上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 还可以划分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

    // s~e范围的累加和
    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

    // 当前节点c,表达的范围l~r
    // 收到了一个任务,s~e这个任务!
    // s~e这个任务,影响了多少l~r范围的数,把答案返回!
    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) {
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

  public static class Right {
    public int[] arr;

    public Right(int size) {
      arr = new int[size + 1];
    }

    public void add(int i, int v) {
      arr[i] += v;
    }

    public int query(int s, int e) {
      int sum = 0;
      for (int i = s; i <= e; i++) {
        sum += arr[i];
      }
      return sum;
    }
  }

  public static void main(String[] args) {
    int size = 10000;
    int testTime = 50000;
    int value = 500;
    DynamicSegmentTree dst = new DynamicSegmentTree(size);
    Right right = new Right(size);
    System.out.println("测试开始");
    for (int k = 0; k < testTime; k++) {
      if (Math.random() < 0.5) {
        int i = (int) (Math.random() * size) + 1;
        int v = (int) (Math.random() * value);
        dst.add(i, v);
        right.add(i, v);
      } else {
        int a = (int) (Math.random() * size) + 1;
        int b = (int) (Math.random() * size) + 1;
        int s = Math.min(a, b);
        int e = Math.max(a, b);
        int ans1 = dst.query(s, e);
        int ans2 = right.query(s, e);
        if (ans1 != ans2) {
          System.out.println("出错了!");
          System.out.println(ans1);
          System.out.println(ans2);
        }
      }
    }
    System.out.println("测试结束");
  }
}

接下来看一个使用动态开点线段树来解决的一个问题

即:LeetCode 315. Count of Smaller Numbers After Self

注:本题可以用归并排序,树状数组,有序表来解,也可以用动态开点线段树来解。

主要思路如下

以如下数组为例来说明

nums = {5,8,7,4,2,9}

首先,初始化一个 List,这个 List 用于存放每个位置的右侧比其小的数有几个,List 的大小和原始数组一样

List<Integer> ans = new ArrayList<>(nums.length);

ans 在初始化的时候,均设置为 0 ,表示,所有位置都还没计算过。

ans = [0,0,0,0,0,0]

接下来对原始数组进行排序(注意:排序的时候,不能只使用值来排序,要带上这个值所在的位置,这样排序后才不会丢失该值在原始数组中的位置信息)

    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要记录值,也要记录位置,防止排序后找不到值对应的位置在哪里
      arr[i] = new int[] {nums[i], i};
    }
    // 排序按值排序
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));

排序后,arr 按如下顺序组织

{值:2,原始位置:4}
{值:4,原始位置:3}
{值:5,原始位置:0}
{值:7,原始位置:2}
{值:8,原始位置:1}
{值:9,原始位置:5}

接下来初始化开点线段树,线段树的size就是原始数组的大小,且每个位置都是0,

按顺序遍历这个 arr 数组,最小值 2 被取出,其原始位置是 4,且 4 号位置右侧没有比自己更小的数,接下来在开点线段树中把把 4 号位置的值加1,表示 4 号位置被处理过了,在线段树中查4号位置以后并没有任何标记记录,说明没有比这个数更小的数了,直接设置4号位置的ans值为0

ans = [0,0,0,0,0,0]
seg = [0,0,0,0,1,0]

接下来是 3 号位置的4,在线段树中查到,有一个比它小的,直接设置到 ans 中,然后在线段树中把 3 号位置也标记为 1,说明处理过,

ans = [0,0,0,1,0,0]
seg = [0,0,0,1,1,0]

接下来是0号位置的5, 在线段树中,查到右侧有两个标记过的,说明有两个比它小的数,直接在 ans 中把 0 号位置设置为 2, 然后在线段树中把 0 号位置标记为 1 ,说明处理过,此时

ans = [2,0,0,1,0,0]
seg = [1,0,0,1,1,0]

接下来是 2 号位置的 7, 在线段树中,查到右侧有两个标记过的,说明有两个比它小的数,直接在 ans 中把 2 号位置设置为 2, 然后在线段树中把 2 号位置标记为 1 ,说明处理过,此时

ans = [2,0,2,1,0,0]
seg = [1,0,1,1,1,0]

接下来是 1 号位置的 8, 在线段树中,查到右侧有三个标记过的,说明有三个比它小的数,直接在 ans 中把 1 号位置设置为 3, 然后在线段树中把 1 号位置标记为 1 ,说明处理过,此时

ans = [2,3,2,1,0,0]
seg = [1,1,1,1,1,0]

接下来是 5 号位置的 9, 在线段树中,查到右侧没有标记过的,说明没有比它小的数,直接在 ans 中把 5 号位置设置为 0, 然后在线段树中把 5 号位置标记为 1 ,说明处理过,此时

ans = [2,3,2,1,0,0]
seg = [1,1,1,1,1,1]

以上就是整个流程。

核心代码如下

  public static List<Integer> countSmaller(int[] nums) {
    if (nums == null || nums.length == 0) {
      return new ArrayList<>();
    }
    int n = nums.length;
    List<Integer> ans = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
      ans.add(0);
    }
    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要记录值,也要记录位置,防止排序后找不到值对应的位置在哪里
      arr[i] = new int[] {nums[i], i};
    }
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));
    DynamicSegmentTree dst = new DynamicSegmentTree(n);
    for (int[] num : arr) {
      ans.set(num[1], dst.query(num[1] + 1, n));
      dst.add(num[1] + 1, 1);
    }
    return ans;
  }

其中 DynamicSegmentTree 结构就是前面提到的动态开点线段树的实现。

更多#

算法和数据结构笔记


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK