2025. 10. 20. 15:02 Baekjoon(Python)/단계별로 풀어보기(Python)

Baekjoon(Python)/단계별로 풀어보기(Python)

[백준][19단계 재귀] 24060번 /알고리즘 수업 - 병합 정렬 1 (파이썬/Python)

junslee 2025. 10. 20. 15:02

1. 문제 설명

2. 코드

import sys
input = sys.stdin.readline

N, K = map(int, input().split())
A = list(map(int, input().split()))

tmp = [0] * N  # 전역 임시 배열 재사용
cnt = 0        # 지금까지의 저장(=A에의 대입) 횟수
answer = -1    # K번째 저장되는 수 (기본값 -1)

def merge_sort(l, r):
    if l >= r:
        return
    m = (l + r) // 2
    merge_sort(l, m)
    merge_sort(m + 1, r)
    merge(l, m, r)

def merge(l, m, r):
    # 1) l..m, m+1..r 두 구간을 tmp에 병합(오름차순)
    i, j, t = l, m + 1, l
    while i <= m and j <= r:
        if A[i] <= A[j]:
            tmp[t] = A[i]
            i += 1
        else:
            tmp[t] = A[j]
            j += 1
        t += 1
    while i <= m:
        tmp[t] = A[i]
        i += 1
        t += 1
    while j <= r:
        tmp[t] = A[j]
        j += 1
        t += 1

    # 2) tmp[l..r]을 A[l..r]에 "저장" (여기가 저장 카운트 지점)
    global cnt, answer
    for i in range(l, r + 1):
        A[i] = tmp[i]
        cnt += 1
        if cnt == K:
            answer = A[i]

merge_sort(0, N - 1)
print(answer)

3. 풀이 과정

이 문제를 해결하기 위해서는 합병 정렬(merge sort) 구현 과정을 이해해야 한다.

'분할 - 정복 - 결합' 과정으로 진행되고 있으며, 코드를 보면서 확인하면 다음과 같다.

def merge_sort(l, r):
    if l >= r:
        return
    m = (l + r) // 2
    merge_sort(l, m)
    merge_sort(m + 1, r)
    merge(l, m, r)

먼저, merge()를 진행할 값들을 찾아야 한다.

예제 1을 가지고 풀이를 진행해 보겠다. (N,K = 5,7) (A = [4,5,1,3,2])

merge_sort(0,4)라면 합병 정렬은 다음과 같이 된다.

merge_sort(0,4)
 ├─ merge_sort(0,2)
 │   ├─ merge_sort(0,1)
 │   │   ├─ merge_sort(0,0)  # 종료
 │   │   └─ merge_sort(1,1)  # 종료
 │   │   └─ merge(0,0,1)
 │   └─ merge_sort(2,2)      # 종료
 │   └─ merge(0,1,2)
 └─ merge_sort(3,4)
     ├─ merge_sort(3,3)      # 종료
     └─ merge_sort(4,4)      # 종료
     └─ merge(3,3,4)
 └─ merge(0,2,4)

이제 merge(0,0,1), merge(0,1,2), merge(3,3,4), merge(0,2,4)를 진행하면 된다.

merge(0,0,1)을 예시로 진행해 보겠다.

def merge(l, m, r):
    # 1) l..m, m+1..r 두 구간을 tmp에 병합(오름차순)
    i, j, t = l, m + 1, l
    while i <= m and j <= r:
        if A[i] <= A[j]:
            tmp[t] = A[i]
            i += 1
        else:
            tmp[t] = A[j]
            j += 1
        t += 1
    while i <= m:
        tmp[t] = A[i]
        i += 1
        t += 1
    while j <= r:
        tmp[t] = A[j]
        j += 1
        t += 1

l,m,r = 0,0,1이고, i,j,t = 0,1,0 이다.

  • i는 왼쪽 구간 A[l...m]을 가리키는 포인터이다.
  • j는 오른쪽 구간 A[m+1...r]을 가리키는 포인터이다.
  • t는 병합 결과를 적어 넣을 임시배열 tmp의 인덱스이다.

여기서 i와 m, j와 r을 비교하면 된다. 이 조건들은 아직 왼쪽 구간과 오른쪽 구간에 원소가 남아 있는 경우를 의미한다.

i,m = 0,0이므로, 루프에 진입해서 tmp[0] = 4를 얻는다. (i,t = 1,1)

j,r = 1,1이므로, 루프에 진입해서 tmp[1] = 5를 얻는다. (j,t = 2,2)

tmp = [4,5,?,?,?] 상태가 된다. 

    global cnt, answer
    for i in range(l, r + 1):
        A[i] = tmp[i]
        cnt += 1
        if cnt == K:
            answer = A[i]

merge_sort(0, N - 1)
print(answer)

global cnt, answer는 cnt와 answer 변수가 전역 변수임을 선언하는 코드이다.

이후, 병합 결과를 A에 저장한다.

A[0] = tmp[0] = 4

A[1] = tmp[1] = 5로 두번 저장하게 된다.

A에 값에 변화는 없지만 merge(0,0,1)로 2번의 저장이 발생하였다.

merge(0,1,2), merge(3,3,4), merge(0,2,4)도 마찬가지로 진행하면

3번,2번,5번으로 총 12회 저장이 발생한다.

원하는 출력값은 K번째이자 7번째를 원하므로 cnt=7일 때의 값 3이 나오면 answer에 저장해서 출력한다.

반응형