프로그래밍/PythonAdvanced

[파이썬] np.sum, 넘파이 배열에 조건문을 달면? (feat.밑바닥부터 시작하는 딥러닝)

자연대생 2023. 7. 16. 22:47
>>> np.sum([0.5, 1.5])
2.0
>>> np.sum([[0, 1], [0, 5]])
6
>>> np.sum([[0, 1], [0, 5]], axis=0)
array([0, 6])
>>> np.sum([[0, 1], [0, 5]], axis=1)
array([1, 5])

원래 사람들이 많이 쓰는 구조이다.

sum()의 axis 인수는 몇 번째 차원인지를 지정하는 값이다.

0번째 차원은 세로, 1번째 차원은 가로이다.

 

그런데 이런 건 어떨까?

import numpy as np
X = np.array([51, 55, 14, 19, 0, 4])

print(X > 15)	# bool 배열
# [ True  True False  True False False]

print(X[X>15])	# True에 해당하는 원소만
# [51 55 19]

 

bool 배열이 만들어진다.

인덱스에 조건문을 달면 True인 원소만 X에 저장되게 된다.

 

그럼 더 나아가서 이런 건 또 어떨까?

x, t = get_data()   # x: 입력값, t: 정답값
network = init_network()

accuracy_cnt = 0

for i in range(0, len(x), 100):
    x_batch = x[i:i+100]	# 
    y_batch = predict(network, x_batch) # 학습된 값들
    p = np.argmax(y_batch, axis=1)  # 학습값 중 최댓값
    accuracy_cnt += np.sum(p == t[i:i+100])
    
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))  # Accuracy:0.9352

get_data()는 말 그대로 데이터를 얻는 만들어낸 함수, 

init_network()는 신경망을 만드는 만들어낸 함수이므로 크게 신경 쓰지 않는다.

 

for문의 인수를 보면

0에서 len(x)까지 100 간격으로 증가하는 정수를 만든다. (len(x)=10000)

x_batch는 x[0:100], x[100:200], ... 이 될 것이다.

y_batch는 그냥 학습된 값이라고 생각하면 된다.

predict는 만들어낸 함수로, 학습에 쓰이는 함수라고 생각하면 된다.

p는 argmax로 학습값 중 최댓값을 뽑아내는 역할을 한다.

이제부터가 본론이다.

np.sum(p == t[i:i+100])

 

이게 대체 뭘까??

import numpy as np
y = np.array([1, 2, 1, 0])
t = np.array([1, 2, 0, 0])
print(y==t) # [ True  True False  True]
print(np.sum(y==t)) # 3

이걸 보면 알 수 있다.

(y==t)라는 조건문도 bool 배열을 만든다.

sum이 나와서 더 헷갈려지는데

여기서의 sum은 True를 1로, False를 0으로 둬서 모두 더한 값을 반환한다.

 

p는 학습값이고 t는 정답값이니

다 더해서 원소의 개수로 나눈 것이 정확도임을 쉽게 이해할 수 있을 것이다.

def accuracy(self, x, t):
    y = self.predict(x)
    y = np.argmax(y, axis=1)
    t = np.argmax(t, axis=1)
        
    accuracy = np.sum(y == t) / float(x.shape[0])
    return accuracy

이 정확도 함수도 np.sum(y==t)가 쓰였다.

똑같이 해석하면 된다.

 

import numpy as np
x = np.array([[1.0, -0.5], [-2.0, 3.0]])
print(x)
# [[ 1.  -0.5]
#  [-2.   3. ]]

mask = (x <= 0)
print(mask)
# [[False  True]
#  [ True False]]

위 예제는 x가 0 이하이면 True를 저장하는 넘파이 배열이다.