numpy 093-彻底读懂numpy的轴(Axis)

请听题

这其实是⼀个github上的练习项⽬,第93题。

https://github.com/rougier/numpy-100

给定⼀个8X3的数组A和⼀个2X2的数组B,从A中找出满⾜条件的⾏,条件是B中每⼀⾏都有至少有1个元素出现在A中这⼀⾏中?(不考虑B中每⾏元素顺序)(提⽰: np.where)

先给答案

import numpy as np
np.random.seed(24) #为了保证每次结果可以重复,设个种子
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))
C = (A[:,:,np.newaxis, np.newaxis] == B)
rows = np.where(C.any((3,1)).all(1))[0]
print (rows)
## [1 5 7]

看不懂题目,可以找个实例看看

print (A,'\n\n',B)
## [[2 3 0]
##  [1 1 1]
##  [4 3 4]
##  [3 2 3]
##  [3 3 3]
##  [1 2 1]
##  [3 4 0]
##  [4 1 1]] 
## 
##  [[4 1]
##  [1 1]]

这个项⽬中凡是★★★以上的都很难。这个题是93题。

即便给了答案,⼤家理解起来依然⾮常难。这并不是⼤家能⼒不够,因为程序员们也普遍反映70以后的题基本觉得⾃⼰就是个傻⽠。

理解了这个题,就彻底学懂了numpy的轴(Axis)

代码解析

无需解读部分

代码前4行,不需要解释

import numpy as np
np.random.seed(24) #为了保证每次结果可以重复,设个种子
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))

第5行解读

C = (A[:,:,np.newaxis, np.newaxis] == B)
print (A[:,:,np.newaxis, np.newaxis])
## [[[[2]]
## 
##   [[3]]
## 
##   [[0]]]
## 
## 
##  [[[1]]
## 
##   [[1]]
## 
##   [[1]]]
## 
## 
##  [[[4]]
## 
##   [[3]]
## 
##   [[4]]]
## 
## 
##  [[[3]]
## 
##   [[2]]
## 
##   [[3]]]
## 
## 
##  [[[3]]
## 
##   [[3]]
## 
##   [[3]]]
## 
## 
##  [[[1]]
## 
##   [[2]]
## 
##   [[1]]]
## 
## 
##  [[[3]]
## 
##   [[4]]
## 
##   [[0]]]
## 
## 
##  [[[4]]
## 
##   [[1]]
## 
##   [[1]]]]

升维操作

np.newaxis可以给numpy创建的数组升高维度,np.newaxis插在哪里就升在哪里。 比如可以试试A[np.newaxis,:]A[:,np.newaxis]有什么不同。

而第5行,还有另外一个写法,A[...,np.newaxis,np.newaxis]。它和A[:,:,np.newaxis, np.newaxis]是等同的。总体来说,…的意思就是,所有维度。这个…可以放在任意位置标示这段所有维度,如[np.newaxis,np.newaxis,...]或者[np.newaxis,...,np.newaxis]

你可以试试下面这个四维代码的语句

x = np.arange(1,81).reshape(4,5,2,2)
print (x)
print (x[...,1])
print(x[:,1])

如果只在后⾯加维度,前⾯的维度不变,⽐如开始的是3⾏2列,3⾏那在第⼀个维度(也就是Axis =0,轴序号是从0开始)始终都不会变,⽆论加多少维度,第⼀维度始终是3。

三维数组的轴示意图

⼀旦超过三维,就很难想象或理解,我们还可以⽤树形结构来表⽰多维关系,这种表达⽅式不需要结构或空间想象。有⼏维,就有⼏个轴(时刻不要忘记轴序号以0开始)。

三维数组的树形结构

数组比较

A[:,:,np.newaxis, np.newaxis] == B是将两个数组进行比较。数组的比较非常有意思,可以直接用==进行运算。

  • 比较是有前提的,即结构绝对相同
  • 若结构相同,同样位置的值,⼀⼀⽐较,形成⼀个新的结构相同的数组,数组⾥只有布尔类型。
  • 若结构不同,则直接返回false并报错。
#代码示例(结构相同)
A = np.array([[1,2,3],[4,5,6]])
B = np.array([[1,2,5],[4,4,5]])
A == B
## array([[ True,  True, False],
##        [ True, False, False]])
#代码示例(结构不同)
d2 = np.arange(20).reshape(4,5)
d3 = np.arange(20).reshape(5,4)
print (d2 == d3)
## False
## 
## <string>:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
A = np.array([10,7,2,0])
B = np.array([10,9,8,7,6,5,4,3,2,1])
print (A == B)
## False

如果两个数组结构不同呢?怎么比较才不会报错?可以升高维度。

⽐如两个结构不同的1维数组,我们把其中一个1维数组升维以后,变成2维,再和另⼀个1维数组⽐较

#代码示例(1维升2维以后)
C = A [...,np.newaxis]
print (C)
## [[10]
##  [ 7]
##  [ 2]
##  [ 0]]
print (C == B)
## [[ True False False False False False False False False False]
##  [False False False  True False False False False False False]
##  [False False False False False False False False  True False]
##  [False False False False False False False False False False]]

升了⼀个维度,相当于就可以让B中的每个数去和升维后的A(即C)去⼀个⼀个⽐较,并最终产⽣了⼀个2维的新数组(4⾏10列)。升⾼维度就是为了能够和不同形状的数组进⾏单个元素⽐较,⽽且升⾼的维度必须为原来的N倍,才可能进⾏⽐较,⽐如原来是2维,那就得变4维甚至6维才能和其比较。

有人可能会问,升维以后,那些没填满的值都是什么?我没有仔细查看资料,但我认为存在两个特性。

  • 继承,即下一维度都是这个值
  • 可变,即可随着比较的数组而变化

比如上面的例子,产⽣了⼀个2维的新数组(4⾏10列),因为要比较的数组B有10个数值,如果B只有5个数值或者3个数值呢?我们可以运行这个代码看看。

#代码示例(1维升2维以后再示例)
A = np.array([10,7,2,0])
B = np.array([10,9,2])
C = A [...,np.newaxis]
print (C == B)
## [[ True False False]
##  [False False False]
##  [False False  True]
##  [False False False]]

由于维度不高,所以我们还是可以用一个表格来表示

2维数组和1维数组比较

通过上面代码的运行结果和这个示意图,我们可以看到在坐标(0,0)和(2,2)是TRUE。

轴的折叠

想要学会np.where这个函数,必须先理解轴的折叠,这也是numpy比较难的知识点。

数组轴的个数,在python的世界中,轴的个数被称作秩,轴的个数与数组中括号“[”或者”]“的个数相同,轴指向多维数组的单个维度:

#数一数维数和[]的关系
A = np.random.randint(0,5,(2,3,4))
print (A)
## [[[1 2 3 0]
##   [2 3 1 1]
##   [3 4 1 2]]
## 
##  [[1 0 4 3]
##   [2 0 1 0]
##   [3 2 3 4]]]

官⽅⽂档的解释:在numpy进行sum或where等方法进行运算时,指定轴的⽅式可能会让来⾃其他语⾔的⽤户感到困惑。axis关键字指定将折叠的数组的维度,⽽不是将要返回的维度。因此,指定axis=0意味着第⼀个轴将折叠:对于⼆维数组,这意味着行不存在了,只存在列,每列中的值将被聚合。我们先随便用一维数组试试,同样的例子。

A = np.array([10,7,2,0])
B = np.array([10,9,8,7,6,5,4,3,2,1])
np.where (B)
## (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64),)
C = A [...,np.newaxis] == B
print (C)
## [[ True False False False False False False False False False]
##  [False False False  True False False False False False False]
##  [False False False False False False False False  True False]
##  [False False False False False False False False False False]]
print ('C location:',np.where (C))
 #输出数组中不为0的坐标
## C location: (array([0, 1, 2], dtype=int64), array([0, 3, 8], dtype=int64))
print ('Axis 1 combine:', C.any(1), np.where(C.any(1)))
## Axis 1 combine: [ True  True  True False] (array([0, 1, 2], dtype=int64),)
print ('Axis 0 combine:', C.any(0), np.where(C.any(0)))
## Axis 0 combine: [ True False False  True False False False False  True False] (array([0, 3, 8], dtype=int64),)

numpy.where有2种⽤法,

  1. np.where(condition, x, y),满⾜条件(condition),输出x,不满⾜输出y。
np.where(A > 5,1,-1)
## array([ 1,  1, -1, -1])
  1. np.where(condition)

只有条件 (condition),没有x和y,则输出满⾜条件 (即⾮0) 元素的轴 (等价于numpy.nonzero)。这⾥的轴以tuple(元组)的形式给出,通常原数组有多少维,输出的tuple中就包含⼏个数组,分别对应符合条件元素的各轴,即可推断坐标。这⾥numpy.where显然属于第2种⽤法。

np.where(C.any(1))意思是把第2个轴(即axis-1)折叠了,即把列折叠了,这样就没有列了,只剩行。从而输出的是已经把列折叠以后的行数。我们还可以理解为在这一行,只要有一个数不是0,那就输出该行的位置。

np.where(C.any(0))则是把行折叠了,输出的列数。理解同上。

上⾯的代码,也是轴的折叠,是⽤where实现,numpy很多命令都是可以实现轴的折叠的,⽐如我们可以⽤numpy.sum函数。2维结构表示,⽆⾮就和前⾯说的⼀样,就是⾏和列的操作。

A = np.arange(1,11).reshape(2,5)
print (A,"\n")
## [[ 1  2  3  4  5]
##  [ 6  7  8  9 10]]
sum0 = A.sum(0)
sum1 = A.sum(1)
print("0轴折叠相加\n",sum0)
## 0轴折叠相加
##  [ 7  9 11 13 15]
print("1轴折叠相加\n",sum1)
## 1轴折叠相加
##  [15 40]

回到我们的93题,我们先从外到内对np.where语句进行拆解。

np.random.seed(24)
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))
C = (A[:,:,np.newaxis, np.newaxis] == B)
print (A,'\n\n',B)
## [[2 3 0]
##  [1 1 1]
##  [4 3 4]
##  [3 2 3]
##  [3 3 3]
##  [1 2 1]
##  [3 4 0]
##  [4 1 1]] 
## 
##  [[4 1]
##  [1 1]]
rows = np.where(C)
print (rows)
## (array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 5, 5, 5, 5, 5, 5, 6, 7, 7, 7, 7,
##        7, 7, 7], dtype=int64), array([0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 2, 0, 0, 0, 2, 2, 2, 1, 0, 1, 1, 1,
##        2, 2, 2], dtype=int64), array([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1,
##        0, 1, 1], dtype=int64), array([1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1,
##        1, 0, 1], dtype=int64))

A = np.random.randint(0,5,(8,3))

8,3树形结构

B = np.random.randint(0,5,(2,2))

2,2树形结构

A[:,:,np.newaxis, np.newaxis]

升维以后,版面不够,我先把第0轴第2个元素提出来了,即Axis0[1],因为这个满足了题目的筛选条件。你知道A升维以后,再和B进行比较,数组编程了多少个元素吗?应该是8*3*2*2=96

第0轴第2个元素

下面我们以Axis0[1]为例,折叠轴了。看好了,首先折叠第4轴,也就是最下面的轴

np.where(C.any((3)))

第4轴折叠

np.where(C.any((3,1)))

第2轴折叠

其实下面的步骤我已经不用再解释了,你应该懂了。你可以自己画一画。 np.where(C.any((3,1)).all(1))

同样的,我们选取一个不满足要求的行,比如Axis0[2],当你看到下图的时候,你就立刻明白为它如何不满足要求了。

第0轴第2个元素