numpy 093-彻底读懂numpy的轴(Axis)
请听题
这其实是⼀个github上的练习项⽬,第93题。
https://github.com/rougier/numpy-100
给定⼀个8X3的数组A和⼀个2X2的数组B,从A中找出满⾜条件的⾏,条件是B中每⼀⾏都有至少有1个元素出现在A中这⼀⾏中?(不考虑B中每⾏元素顺序)(提⽰: np.where)
先给答案
import numpy as np
np.random.seed(24) #为了保证每次结果可以重复,设个种子
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))
C = (A[:,:,np.newaxis, np.newaxis] == B)
rows = np.where(C.any((3,1)).all(1))[0]
print (rows)## [1 5 7]
看不懂题目,可以找个实例看看
print (A,'\n\n',B)## [[2 3 0]
## [1 1 1]
## [4 3 4]
## [3 2 3]
## [3 3 3]
## [1 2 1]
## [3 4 0]
## [4 1 1]]
##
## [[4 1]
## [1 1]]
这个项⽬中凡是★★★以上的都很难。这个题是93题。
即便给了答案,⼤家理解起来依然⾮常难。这并不是⼤家能⼒不够,因为程序员们也普遍反映70以后的题基本觉得⾃⼰就是个傻⽠。
理解了这个题,就彻底学懂了numpy的轴(Axis)
代码解析
无需解读部分
代码前4行,不需要解释
import numpy as np
np.random.seed(24) #为了保证每次结果可以重复,设个种子
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))第5行解读
C = (A[:,:,np.newaxis, np.newaxis] == B)
print (A[:,:,np.newaxis, np.newaxis])## [[[[2]]
##
## [[3]]
##
## [[0]]]
##
##
## [[[1]]
##
## [[1]]
##
## [[1]]]
##
##
## [[[4]]
##
## [[3]]
##
## [[4]]]
##
##
## [[[3]]
##
## [[2]]
##
## [[3]]]
##
##
## [[[3]]
##
## [[3]]
##
## [[3]]]
##
##
## [[[1]]
##
## [[2]]
##
## [[1]]]
##
##
## [[[3]]
##
## [[4]]
##
## [[0]]]
##
##
## [[[4]]
##
## [[1]]
##
## [[1]]]]
升维操作
np.newaxis可以给numpy创建的数组升高维度,np.newaxis插在哪里就升在哪里。 比如可以试试A[np.newaxis,:]和A[:,np.newaxis]有什么不同。
而第5行,还有另外一个写法,A[...,np.newaxis,np.newaxis]。它和A[:,:,np.newaxis, np.newaxis]是等同的。总体来说,…的意思就是,所有维度。这个…可以放在任意位置标示这段所有维度,如[np.newaxis,np.newaxis,...]或者[np.newaxis,...,np.newaxis]
你可以试试下面这个四维代码的语句
如果只在后⾯加维度,前⾯的维度不变,⽐如开始的是3⾏2列,3⾏那在第⼀个维度(也就是Axis =0,轴序号是从0开始)始终都不会变,⽆论加多少维度,第⼀维度始终是3。
三维数组的轴示意图
⼀旦超过三维,就很难想象或理解,我们还可以⽤树形结构来表⽰多维关系,这种表达⽅式不需要结构或空间想象。有⼏维,就有⼏个轴(时刻不要忘记轴序号以0开始)。
三维数组的树形结构
数组比较
A[:,:,np.newaxis, np.newaxis] == B是将两个数组进行比较。数组的比较非常有意思,可以直接用==进行运算。
- 比较是有前提的,即结构绝对相同。
- 若结构相同,同样位置的值,⼀⼀⽐较,形成⼀个新的结构相同的数组,数组⾥只有布尔类型。
- 若结构不同,则直接返回
false并报错。
#代码示例(结构相同)
A = np.array([[1,2,3],[4,5,6]])
B = np.array([[1,2,5],[4,4,5]])
A == B## array([[ True, True, False],
## [ True, False, False]])
#代码示例(结构不同)
d2 = np.arange(20).reshape(4,5)
d3 = np.arange(20).reshape(5,4)
print (d2 == d3)## False
##
## <string>:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
A = np.array([10,7,2,0])
B = np.array([10,9,8,7,6,5,4,3,2,1])
print (A == B)## False
如果两个数组结构不同呢?怎么比较才不会报错?可以升高维度。
⽐如两个结构不同的1维数组,我们把其中一个1维数组升维以后,变成2维,再和另⼀个1维数组⽐较
#代码示例(1维升2维以后)
C = A [...,np.newaxis]
print (C)## [[10]
## [ 7]
## [ 2]
## [ 0]]
print (C == B)## [[ True False False False False False False False False False]
## [False False False True False False False False False False]
## [False False False False False False False False True False]
## [False False False False False False False False False False]]
升了⼀个维度,相当于就可以让B中的每个数去和升维后的A(即C)去⼀个⼀个⽐较,并最终产⽣了⼀个2维的新数组(4⾏10列)。升⾼维度就是为了能够和不同形状的数组进⾏单个元素⽐较,⽽且升⾼的维度必须为原来的N倍,才可能进⾏⽐较,⽐如原来是2维,那就得变4维甚至6维才能和其比较。
有人可能会问,升维以后,那些没填满的值都是什么?我没有仔细查看资料,但我认为存在两个特性。
- 继承,即下一维度都是这个值
- 可变,即可随着比较的数组而变化
比如上面的例子,产⽣了⼀个2维的新数组(4⾏10列),因为要比较的数组B有10个数值,如果B只有5个数值或者3个数值呢?我们可以运行这个代码看看。
#代码示例(1维升2维以后再示例)
A = np.array([10,7,2,0])
B = np.array([10,9,2])
C = A [...,np.newaxis]
print (C == B)## [[ True False False]
## [False False False]
## [False False True]
## [False False False]]
由于维度不高,所以我们还是可以用一个表格来表示
2维数组和1维数组比较
通过上面代码的运行结果和这个示意图,我们可以看到在坐标(0,0)和(2,2)是TRUE。
轴的折叠
想要学会np.where这个函数,必须先理解轴的折叠,这也是numpy比较难的知识点。
数组轴的个数,在python的世界中,轴的个数被称作秩,轴的个数与数组中括号“[”或者”]“的个数相同,轴指向多维数组的单个维度:
#数一数维数和[]的关系
A = np.random.randint(0,5,(2,3,4))
print (A)## [[[1 2 3 0]
## [2 3 1 1]
## [3 4 1 2]]
##
## [[1 0 4 3]
## [2 0 1 0]
## [3 2 3 4]]]
官⽅⽂档的解释:在numpy进行sum或where等方法进行运算时,指定轴的⽅式可能会让来⾃其他语⾔的⽤户感到困惑。axis关键字指定将折叠的数组的维度,⽽不是将要返回的维度。因此,指定axis=0意味着第⼀个轴将折叠:对于⼆维数组,这意味着行不存在了,只存在列,每列中的值将被聚合。我们先随便用一维数组试试,同样的例子。
A = np.array([10,7,2,0])
B = np.array([10,9,8,7,6,5,4,3,2,1])
np.where (B)## (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64),)
C = A [...,np.newaxis] == B
print (C)## [[ True False False False False False False False False False]
## [False False False True False False False False False False]
## [False False False False False False False False True False]
## [False False False False False False False False False False]]
print ('C location:',np.where (C))
#输出数组中不为0的坐标## C location: (array([0, 1, 2], dtype=int64), array([0, 3, 8], dtype=int64))
print ('Axis 1 combine:', C.any(1), np.where(C.any(1)))## Axis 1 combine: [ True True True False] (array([0, 1, 2], dtype=int64),)
print ('Axis 0 combine:', C.any(0), np.where(C.any(0)))## Axis 0 combine: [ True False False True False False False False True False] (array([0, 3, 8], dtype=int64),)
numpy.where有2种⽤法,
- np.where(condition, x, y),满⾜条件(condition),输出x,不满⾜输出y。
np.where(A > 5,1,-1)## array([ 1, 1, -1, -1])
- np.where(condition)
只有条件 (condition),没有x和y,则输出满⾜条件 (即⾮0) 元素的轴 (等价于numpy.nonzero)。这⾥的轴以tuple(元组)的形式给出,通常原数组有多少维,输出的tuple中就包含⼏个数组,分别对应符合条件元素的各轴,即可推断坐标。这⾥numpy.where显然属于第2种⽤法。
np.where(C.any(1))意思是把第2个轴(即axis-1)折叠了,即把列折叠了,这样就没有列了,只剩行。从而输出的是已经把列折叠以后的行数。我们还可以理解为在这一行,只要有一个数不是0,那就输出该行的位置。
np.where(C.any(0))则是把行折叠了,输出的列数。理解同上。
上⾯的代码,也是轴的折叠,是⽤where实现,numpy很多命令都是可以实现轴的折叠的,⽐如我们可以⽤numpy.sum函数。2维结构表示,⽆⾮就和前⾯说的⼀样,就是⾏和列的操作。
A = np.arange(1,11).reshape(2,5)
print (A,"\n")## [[ 1 2 3 4 5]
## [ 6 7 8 9 10]]
sum0 = A.sum(0)
sum1 = A.sum(1)
print("0轴折叠相加\n",sum0)## 0轴折叠相加
## [ 7 9 11 13 15]
print("1轴折叠相加\n",sum1)## 1轴折叠相加
## [15 40]
回到我们的93题,我们先从外到内对np.where语句进行拆解。
np.random.seed(24)
A = np.random.randint(0,5,(8,3))
B = np.random.randint(0,5,(2,2))
C = (A[:,:,np.newaxis, np.newaxis] == B)
print (A,'\n\n',B)## [[2 3 0]
## [1 1 1]
## [4 3 4]
## [3 2 3]
## [3 3 3]
## [1 2 1]
## [3 4 0]
## [4 1 1]]
##
## [[4 1]
## [1 1]]
## (array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 5, 5, 5, 5, 5, 5, 6, 7, 7, 7, 7,
## 7, 7, 7], dtype=int64), array([0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 2, 0, 0, 0, 2, 2, 2, 1, 0, 1, 1, 1,
## 2, 2, 2], dtype=int64), array([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1,
## 0, 1, 1], dtype=int64), array([1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1,
## 1, 0, 1], dtype=int64))
A = np.random.randint(0,5,(8,3))
8,3树形结构
B = np.random.randint(0,5,(2,2))
2,2树形结构
A[:,:,np.newaxis, np.newaxis]
升维以后,版面不够,我先把第0轴第2个元素提出来了,即Axis0[1],因为这个满足了题目的筛选条件。你知道A升维以后,再和B进行比较,数组编程了多少个元素吗?应该是8*3*2*2=96
第0轴第2个元素
下面我们以Axis0[1]为例,折叠轴了。看好了,首先折叠第4轴,也就是最下面的轴
np.where(C.any((3)))
第4轴折叠
np.where(C.any((3,1)))
第2轴折叠
其实下面的步骤我已经不用再解释了,你应该懂了。你可以自己画一画。 np.where(C.any((3,1)).all(1))
同样的,我们选取一个不满足要求的行,比如Axis0[2],当你看到下图的时候,你就立刻明白为它如何不满足要求了。
第0轴第2个元素