알고리즘

Binary Indexed Tree

fenec_fox 2024. 5. 24. 09:17

Binary Indexed Tree(또는 fenwick tree, 이하 BIT)는 Segment Tree와 비슷하게 구간에 대한 정보를 저장할 수 있는 자료구조입니다. 기본적으로 아래의 그림과 같은 1-based 구조를 가지며, 구간합 등을 알아내는데 사용됩니다.

https://www.topcoder.com/community/competitive-programming/tutorials/binary-indexed-trees

BIT는 위의 그림처럼 세그먼트트리보다 조금 더 단순화된 구조를 가지지만, 구간에 대한 정보를 [1,X]으로 알아낼 수 있다는 점이 다릅니다.

위 그림의 2번지에 그려진 1번지~2번지의 박스는 [1,2]의 구간의 정보를 [2,2]에 가지고 있다는 것을 의미합니다. 마찬가지로 12번지는 [9,12]의 구간의 정보를 가지고 있는 것입니다. 이를 적절히 이용하여 원하는 구간의 정보들을 구하는 것입니다.

BIT로 구간의 합을 알아보겠습니다.
트리의 구성은 먼저 최대 구간으로 배열을 잡습니다. 1-based 트리구조이므로 유의하여 배열을 선언합니다. 트리에 자료를 삽입할 때는, 트리의 원하는 번지 X에 원하는 값 Y를 누적하여줍니다. 그리고 X를 X+LSB(X)로 옮겨, 옮겨진 위치에도 Y를 누적시켜줍니다. 이 방법을 X가 배열의 범위를 벗어나지 않을 때까지 반복합니다. (*LSB=최하위비트)

<1의 값은 1,2,4,8,16에 모두 더해짐>

이렇게 완성된 트리에서 [1, X] 구간의 합을 알아내는 과정은 다음과 같습니다. 먼저 X를 11이라고 할 때, 직관적으로 우리는 트리의 8번지와 10번지와 11번지를 모두 더하면 1~11번지의 합을 구할 수 있음을 알 수 있습니다.

이 또한 LSB를 이용하여 해결할 수 있습니다.
X = 1일 때, 11 - LSB(11) = 10이 됩니다. 또 10 - LSB(10) = 8, 8 - LSB(8) = 0입니다.
결국 [1,X]의 누적합을 알기 위해선, X의 트리 값을 더한 후 X를 X-LSB(X)로 옮겨가면서 X가 0이 될 때까지 반복하면 누적합을 얻을 수 있습니다.

BIT는 세그먼트 트리에 비해 간단한 구현과 적은 메모리로 구현이 가능하지만, 세그먼트 트리는 다양한 목적으로 사용이 가능하므로 필요에 따라 사용해야 합니다.

다음은 누적된 값을 알아내는 BIT 코드입니다.

더보기
입력 출력
첫 번째 줄에 배열 D의 원소의 수 N,
두 번째 줄에 D의 원소들이 주어지며
세 번째 줄에 정수 M이 주어진다.
네 번째 줄 부터 M + 3 번째 줄 까지
정수 Xi가 주어진다.
각각의 정수 Xi에 대해 D[1, Xi]의 합을 출력한다.
7
1 -5 6 7 2 3 10
3
5
1
7
11
1
24
import java.util.ArrayList;
import java.util.Scanner;

class BIT {
    ArrayList<Integer> T;
    int S;

    public BIT(int n) {
        this.S = n;
        T = new ArrayList<Integer>();
        for (int i = 0; i <= n; i++) T.add(0);
    }

    void Update(int p, int v) {
        while (p <= S) {
            T.set(p, T.get(p) + v);
            p += p & (-p);
        }
    }

    int Sum(int p) {
        int ret = 0;
        while (p > 0) {
            ret += T.get(p);
            p -= p & (-p);
        }
        return ret;
    }
}

public class Main {
    int v[];

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();

        BIT bit = new BIT(n);

        for (int i = 1; i <= n; i++) {
            int x = sc.nextInt();
            bit.Update(i, x);
        }

        int m = sc.nextInt();

        for (int i = 0; i < m; i++) {
            int x = sc.nextInt();
            System.out.println(bit.Sum(x));
        }
    }
}