[프로그래머스] 2개 이하로 다른 비트

2 분 소요

문제

출처

함수 f는 음이 아닌 정수 x에 대하여, x보다 크고 x와 비트가 1개 또는 2개 다른 수 중 제일 작은 수를 계산한다. 음이 아닌 정수가 담긴 배열이 주어질 때, 각 원소에 f를 적용한 값을 배열에 담아 return하라.

1 <= numbers의 길이 <= 10^5
0 <= numbers의 모든 수 <= 10^15

접근

1차

  • 어떤 수 a와 b가 주어질 때, 비트가 다른 지점의 수는 xor 연산과 선형 탐색으로 구할 수 있다.
  • y = x + 1에서 시작하여 y값을 1씩 증가시켜가며 f(y)가 주어진 조건을 만족하는지 확인한다. 가장 먼저 발견된 y가 조건을 만족하는 가장 작은 수가 된다.

구현하면 다음과 같다.

def bit_dif(a, b):
    ret = 0
    tmp = a^b
    while tmp > 0:
        if tmp % 2 == 1:
            ret += 1
        tmp //= 2
    return ret
        
def brute_force(x):
    n = x + 1
    while True:
        res = bit_dif(x, n)
        if 0 < res < 3:
            return n
        n += 1

def solution(numbers):
    return [brute_force(n) for n in numbers]

하지만 이 방법은 시간복잡도가 지나치게 크다. 대략 O(배열 길이 * 비트 길이 * 수의 크기) 이므로 O(mnlog(n))가 된다.

나쁜 케이스를 예로 들자면 0111 1111 1111과 같은 수를 생각해볼 수 있다. 이 수로 f를 계산해보면 1을 더했을 때 1000 0000 0000이 되어서 원래 수와 1 또는 2비트 차이가 날 때까지 한참을 더해야 하는 것을 알 수 있으며, 실제로 제출해도 시간 초과를 받는다.

2차

  • brute_force로 구한 답을 1부터 100까지 뿌려서 확인해보자.

x, f(x), f(x) - x 순으로 출력했는데, f(x) - x가 2의 배수로 나타나는 패턴이 보인다.

0 1 1
1 2 1
2 3 1
3 5 2
4 5 1
5 6 1
6 7 1
7 11 4
8 9 1
9 10 1
10 11 1
11 13 2
12 13 1
13 14 1
14 15 1
15 23 8

2진수로 바꿔서 살펴보면 보다 분명해진다.

11: 1011
13: 1101

15:  1111
23: 10111

가장 낮은 자리부터 연속된 1을 세서, 그중 가장 높은 자리의 1을 x에 더해주면 f(x)가 된다.
왜 그럴까?

x의 2진수 표현의 낮은 자리의 패턴을 생각해보자.

0으로 끝난다면 f(x)는 자명하게 x + 1이다.

0 -> 1
1개 비트만 변하면서 0보다 크다.

01로 끝나도 마찬가지로 f(x)는 x + 1이다.

01 -> 10 
2개 비트만 변하면서 01보다 크다

011로 끝나면? 이야기가 달라진다.

011 -> 100
1을 더했을 때 3개 비트가 변한다.

그런데 011에는 01이 포함된다. 즉 10을 더하면 101이 되어 2개 비트만 변한다.
0111도 마찬가지로 01을 포함하므로, 100을 더하면 1011이 되어 2개 비트만 변한다.

이렇게 구한 값은 x보다 크면서 1개 또는 2개 비트만 다른 수 중 가장 작은 값일까?
0111로 끝나는 값보다 큰 값은 당연히 1xxx꼴이 되어 x와 1비트가 이미 다르다. 남은 자리들 중 1비트만 변해야 하므로, 1011, 1101, 1110 셋 중 하나가 답이다. 이중 가장 작은 것이 바로 가장 높은 자리에 1을 더한 1011이 된다.

따라서 가장 낮은 자리부터 연속된 1을 세서, 그중 가장 높은 자리의 1을 x에 더해주면 f(x)가 된다.

구현

def f(x):
    tmp_x = x
    cnt = 0
    while tmp_x > 0:
        if tmp_x % 2 == 0:
            break
        tmp_x //= 2
        cnt += 1
    if cnt == 0:
        return x + 1
    return x + (1 << cnt - 1)
    
def solution(numbers):
    return [f(n) for n in numbers]

댓글남기기