파이썬 API 둘러보기

numpy.where()

조건을 만족하는 원소의 인덱스를 뽑아 준다.

리턴값이 튜플인데, 1차원에서는 쉽다.

>>> arr = np.array(range(10))

>>> arr
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

>>> np.where(arr % 2 == 0)
(array([0, 2, 4, 6, 8], dtype=int32),)

>>> arr[np.where(arr % 2 == 0)[0]] = 0

>>> arr
array([0, 1, 0, 3, 0, 5, 0, 7, 0, 9])

 

 2차원 이상은...

예를 들어 np.where(arr == 0)의 결과가

(array([1, 2], dtype=int32), array([3, 4], dtype=int32))

이면 이건 arr[1,3]과 arr[2,4] 를 가리킨다.

 

이 변환 과정이 좀 어려웠는데, 덕분에 문법 공부가 좀 됐다.

아래 코드에서 이 두 줄이 포인트다.

>>> where = np.where(arr >= 10)

>>> indices = tuple(zip(*where))

>>> arr = np.arange(12).reshape((3,4))

>>> arr
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

>>> np.where(arr >= 10)
(array([2, 2], dtype=int32), array([2, 3], dtype=int32))

>>> where = np.where(arr >= 10)

>>> indices = tuple(zip(*where))

>>> indices
((2, 2), (2, 3))

>>> arr[indices] = -1

>>> arr
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, -1, -1]])

댓글

댓글 본문
작성자
비밀번호
버전 관리
장과장02
현재 버전
선택 버전
graphittie 자세히 보기