본문 바로가기

2020-2021 동계 모각코

[붕어빵 꼬리먼저 팀] 3회차 - 학습 마무리

2021/01/19 동계 모각코 3회차

오늘 공부하고자 했던 머지 소트 트리에 대한 이해와 문제 풀이를 완료하였습니다.

 

원래 계획했던 문제는 수열과 쿼리 1 (www.acmicpc.net/problem/13537) 였습니다.

그러나 해당 문제는 오프라인 쿼리를 활용하여 머지 소트 트리 없이도 해결이 가능한 문제였기에,

입력에 이전 쿼리의 정답을 요구하여 온라인 쿼리를 강제한 수열과 쿼리 3 (www.acmicpc.net/problem/13544) 를 풀기로 하였습니다.

 

해당 문제는 온라인 쿼리를 강제하여 오프라인 쿼리를 사용한 세그먼트 트리 풀이를 막았기 때문에 머지 소트 트리로만 풀어야 합니다.

 


[머지 소트 트리] www.acmicpc.net/problem/13544

머지 소트 트리를 사용하여 특정 구간에서 K보다 큰 원소의 개수를 구하는 문제입니다.

 

머지 소트 트리는 $log_2N$의 재귀를 사용하는 머지 소트와 세그먼트 트리를 결합한 트리입니다.

특정 구간의 구간 합, 최솟값, 최댓값 등을 저장하던 세그먼트 트리에서 정렬된 배열을 넣은 것이라고 생각할 수 있습니다.

그런데 위에서 서술했듯 세그먼트 트리의 갱신 과정이 너무나 머지 소트와 결합하기 좋은 형태였기에 머지 소트로 정렬한 트리 형태를 갖게 된 것입니다.

 

머지 소트 트리를 구현하고 left와 right를 조정하면서 left와 right가 원하는 범위 안에 들었다면 해당 노드에 저장했던 정렬된 배열에서 K보다 큰 원소의 개수를 찾아내면 됩니다.

이때, 배열이 정렬되었으므로 이분 탐색을 사용하고 얻어낸 값들을 전부 더하여 return하면 답을 구할 수 있습니다.

 

 

이 문제를 Java로 풀이하였는데, C++로 풀이를 한다면 훨씬 더 간단합니다.

문제를 풀이할 때, upper_bound와 merge를 사용해야 하지만 C++과 다르게 자바는 이를 지원하지 않기 때문입니다.

 

그러나 직접 구현하면서 merge에서 삼항 연산자를 깔끔하게 사용한 것 같아 나름 뿌듯했습니다.

 

 

[후기]

 

상당히 단순해 보이지만 생소한 개념에 대한 문제 풀이는 것은 늘 어려운 과정입니다.

해당 문제를 처음 제출했을 때 시간 초과가 발생했는데, 범위를 조정하면서 찾은 모든 배열들을 merge하는 말도 안 되게 비효율적인 과정을 거치고 통합되어 나온 배열에 이분 탐색을 하였습니다.

 

정말 말도 안 되는 풀이지만, 제출하고 나서 20분 이후까지도 무엇이 잘못됐었는지 알아내는 것이 힘들었습니다.

처음 풀이하는 개념이라 무엇을 잘못 구현했는지, 그래서 무한 루프가 발생했는지, 혹은 그냥 비효율적으로 구현했는지 확인하는 것도 어렵기 때문입니다.

또한 모든 주의가 새로운 개념에 대한 내용으로 끌어져 있기 때문에 이런 말도 안되는 사고를 갖게 될 가능성이 더 커지는 것 같습니다.

 

여러 유형에 대해 충분히 익숙한 상태를 갖는다는 것이 얼마나 중요한지 다시금 깨달았습니다.

 

 

[소스 코드] / Java

 

import java.io.*;
import java.util.*;

public class Main {
    private static int[] arr;
    private static int[][] tree;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        int N = Integer.parseInt(br.readLine());
        arr = new int[N + 1]; tree = new int[4 * N][];

        StringTokenizer st = new StringTokenizer(br.readLine(), " ");
        for (int i = 1; i < N + 1; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }

        init(1, N, 1);

        int M = Integer.parseInt(br.readLine());
        int last_ans = 0;

        StringBuilder sb = new StringBuilder();
        while (M-- > 0) {
            st = new StringTokenizer(br.readLine(), " ");
            int ru = Integer.parseInt(st.nextToken()), rv = Integer.parseInt(st.nextToken()), rk = Integer.parseInt(st.nextToken());
            int u = ru ^ last_ans, v = rv ^ last_ans, k = rk ^ last_ans;

            last_ans = getAns(1, N, 1, u, v, k);
            sb.append(last_ans).append('\n');
        }

        System.out.println(sb.toString());
        br.close();
    }

    private static int getAns(int left, int right, int node, int u, int v, int k) {
        if (right < u || v < left) return 0;
        else if (u <= left && right <= v) return binarySearch(tree[node], k);
        int mid = (left + right) / 2;
        return getAns(left, mid, node * 2, u, v, k) + getAns(mid + 1, right, node * 2 + 1, u, v, k);
    }

    private static int[] init(int left, int right, int node) {
        if (left == right) {
            tree[node] = new int[1];
            tree[node][0] = arr[left];
        } else {
            int mid = (left + right) / 2;
            int[] l = init(left, mid, node * 2), r = init(mid + 1, right, node * 2 + 1);
            tree[node] = merge(l, r);
        }
        return tree[node];
    }

    private static int[] merge(int[] l, int[] r) {
        int lsize = l.length, rsize = r.length;
        int[] temp = new int[lsize + rsize];

        int i = 0, j = 0;
        int idx = 0;
        while (i < lsize && j < rsize) {
            temp[idx++] = l[i] < r[j] ? l[i++] : r[j++];
        }

        for (; i < lsize; i++) {
            temp[idx++] = l[i];
        }
        for (; j < rsize; j++) {
            temp[idx++] = r[j];
        }

        return temp;
    }

    private static int binarySearch(int[] branch, int k) {
        int size = branch.length;
        int left = 0, right = branch.length - 1;
        int result = size;

        while (left <= right) {
            int mid = (left + right) / 2;
            if (branch[mid] > k) {
                result = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        return size - result;
    }
}