Mojo 学习 —— buffer
前言
我们接下去学习内置的 buffer
包,这个模块主要是对指针进行封装,其底层还是对指针进行操作。
它还提供了一种 NDBuffer
类,可以将指针封装为多维数据,便于以矩阵的方式访问数据。
还有一个需要注意的是,内置的 reduction
模块提供的聚集函数都是基于 buffer
类型的。所以还是有必要学习一下 buffer
模块的使用。
buffer
包中包含两个模块
buffer
:实现了Buffer
类list
:提供处理静态和可变列表的实用程序
list
list
模块中实现了两个结构体
Dim
:使用一个可选的整数来构建静态(指定值)或动态维度(未指定值)
其底层封装的是 OptionalReg[Int]
类型
from buffer.list import Dim, DimList
from testing import assert_equal
fn main() raises:
var static_dim = Dim(10)
var dynamic_dim = Dim()
assert_equal(static_dim.has_value(), True)
assert_equal(static_dim.get(), 10)
print('static dim:', static_dim)
assert_equal(dynamic_dim.has_value(), False)
print('dynamic dim:', dynamic_dim)
# static dim: 10
# dynamic dim: ?
DimList
:表示维度列表。每个维度可能是一个静态值,也可能没有值(动态维度)
维度列表,封装的是 VariadicList[Dim]
类型,接受可变数量的维度值。例如
from buffer.list import Dim, DimList
from testing import assert_equal
fn main() raises:
var static_dim = Dim(10)
var dynamic_dim = Dim()
var dlist = DimList(3, static_dim, dynamic_dim)
assert_equal(dlist.get[0](), 3)
assert_equal(dlist.get[1](), 10)
assert_equal(dlist.get[2](), 0)
assert_equal(dlist.at[2](), '?')
assert_equal(dlist.all_known[3](), False)
assert_equal(dlist.contains[2](3), True)
assert_equal(dlist.product[2](), 30)
buffer
该模块包含三个结构体
Buffer
: 可以通过编译时参数化的方式指定大小和类型NDBuffer
: 多维Buffer
DynamicRankBuffer
: 表示维度、形状和数据类型未知的Bffer
类型
Buffer
Buffer
底层封装的是一个指针,但不拥有该指针。例如
from buffer.list import Dim, DimList
from buffer import Buffer
from testing import assert_equal
fn main():
var ptr = DTypePointer[DType.int8].alloc(10)
var buf = Buffer[DType.int8, 10](ptr)
print('buffer size:', len(buf))
buf.store[width=4](0, SIMD[DType.int8, 4](9, 5, 2, 7))
buf.simd_nt_store[width=4](4, SIMD[DType.int8, 4](7, 2, 5, 9))
buf[8] = 1
print('buffer bytecount: ', buf.bytecount())
print('buffer values: ', buf.load[width=8](0), buf.load[width=2](8))
print('second value:', ptr[1]) # 使用指针获取值
ptr.free()
# buffer size: 10
# buffer bytecount: 10
# buffer values: [9, 5, 2, 7, 7, 2, 5, 9] [1, 0]
# second value: 5
我们可以使用指针来获取值,并且要手动将释放内存,Buffer
更像是一个代理。
如果我们将长度设置为 8
,并不会影响对后续元素的获取,还是可以直接从指针中获取到超出长度范围的值。例如,修改上面的代码
var buf = Buffer[DType.int8, 8](ptr)
# buffer size: 8
# buffer bytecount: 8
# buffer size: [9, 5, 2, 7, 7, 2, 5, 9] [1, 0]
我们可以使用指定值来一次性修改所有值。例如,添加下面的代码
buf.zero()
print('buffer values: ', buf.load[width=8](0), buf.load[width=2](8))
buf.fill(2)
print('buffer values: ', buf.load[width=8](0), buf.load[width=2](8))
# uffer values: [0, 0, 0, 0, 0, 0, 0, 0] [1, 0]
# buffer values: [2, 2, 2, 2, 2, 2, 2, 2] [1, 0]
可以看到,只修改了 buf
长度内的值,超出长度的值无法使用这种方式修改。
如果我们传入的是动态的维度,则需要在运行时参数中指定长度。例如
alias dim = Dim()
var buf = Buffer[DType.int8, dim](ptr, 8)
NDBuffer
NDBuffer
用来创建多维 Buffer
数据,编译时参数的 size
变成了 rank
(维度) 和 shape
(大小)。
例如,创建一个 3x4
的 NDBuffer
对象
from buffer.list import Dim, DimList
from buffer import NDBuffer
from testing import assert_equal
fn main() raises:
var ptr = DTypePointer[DType.int8].alloc(12)
alias rank = 2
alias shape = DimList(3, 4)
var buf = NDBuffer[DType.int8, rank, shape](ptr)
# var buf = NDBuffer[DType.int8, rank](ptr, shape)
buf.fill(1)
assert_equal(buf.size(), 12)
assert_equal(buf.num_elements(), 12)
assert_equal(len(buf), 12)
print(buf)
ptr.free()
# NDBuffer([[1, 1, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1]], dtype=int8, shape=3x4)
也可以使用 StaticIntTuple
来初始化
alias shape = StaticIntTuple[rank](3, 4)
var buf = NDBuffer[DType.int8, rank](ptr, shape)
获取维度信息
assert_equal(buf.get_rank(), 2)
assert_equal(buf.get_shape(), StaticIntTuple[2](3, 4))
assert_equal(buf.get_nd_index(2), StaticIntTuple[2](0, 2))
其中,get_nd_index
用于将元素在指针中的一维位置信息转换为 NDBuffer
的多维位置坐标。
存储和访问数据
from buffer.list import Dim, DimList
from buffer import NDBuffer
from testing import assert_equal
fn main() raises:
var ptr = DTypePointer[DType.int8].alloc(12)
alias rank = 2
alias shape = DimList(3, 4)
var buf = NDBuffer[DType.int8, rank](ptr, shape)
buf.simd_nt_store[width=8]((0, 0), SIMD[DType.int8, 8](1, 2, 3, 4, 5, 6, 7, 8))
buf.store[width=4]((2, 0), SIMD[DType.int8, 4](9, 10, 11, 12))
print(buf)
print(buf.dim(1))
print(buf.load[width=4](0, 1))
print(buf[2, 0])
ptr.free()
# NDBuffer([[1, 2, 3, 4],
# [5, 6, 7, 8],
# [9, 10, 11, 12]], dtype=int8, shape=3x4)
# 4
# [2, 3, 4, 5]
# 9
flatten
可以将其转换为一维的 Buffer
对象
var flat = buf.flatten()
for i in range(len(flat)):
print(flat[i], end=' ')
我们还可以设置每个维度的步长。例如上面的例子中,数据的维度为 3x4
,也就是说,第一维的步长为 4
,第二维的步长为 1
。这里说的步长是针对指针所表示的一维数据而言,如果以行列来代表两个维度,则行每次跨越 4
个数,列每次跨越 1
个数。
知道了这个概念,可以很容易的理解下面的代码
var buf = NDBuffer[DType.int8, rank](ptr, shape, (4, 2))
buf.simd_nt_store[width=8]((0, 0), SIMD[DType.int8, 8](1, 2, 3, 4, 5, 6, 7, 8))
print(buf)
print('strid at 1-dim:', buf.stride(0))
print('strid at 2-dim:', buf.stride(1))
print('value at (0, 1)', buf[0, 1])
print('value at (0, 2)', buf[0, 3])
print('value at (1, 1)', buf[1, 1])
# NDBuffer([[1, 2, 3, 4],
# [5, 6, 7, 8],
# [0, 0, 0, 0]], dtype=int8, shape=3x4)
# strid at 1-dim: 4
# strid at 2-dim: 2
# value at (0, 1) 3
# value at (0, 2) 7
# value at (1, 1) 7
每一行的步长为 4,每一列的步长为 2,所以第一行为 1,3,5,7
,第二行的开头是 5
。
tile 可以进行取子集的操作,例如
print(buf.tile[1, 4]((1, 0)))
DynamicRankBuffer
它的效率不如前面的对象,但在与外部函数交互时非常有用,特别是形状被表示为一个固定(_MAX_RANK
)维度的数组,以简化 ABI
。
DynamicRankBuffer
表示未知维度,形状和类型的对象。例如
fn main() raises:
var ptr = DTypePointer[DType.invalid].alloc(8)
memset_zero(ptr, 8)
var buf = DynamicRankBuffer(ptr, 2, (2, 4), DType.int8)
assert_equal(buf.dim(1), 4)
print(buf.to_ndbuffer[DType.int8, 2]())
var buffer = buf.to_buffer[DType.int8]()
assert_equal(len(buffer), 8)
ptr.free()
# NDBuffer([[0, 0, 0, 0],
# [0, 0, 0, 0]], dtype=int8, shape=2x4)
functions
该模块还提供了两个函数,可以对指针数据进行范围填充。例如
fn main():
var ptr = DTypePointer[DType.int8].alloc(4)
memset_zero(ptr, 4)
# 填充 [1,3) 索引区间的值
partial_simd_store[4](ptr, 1, 3, SIMD[DType.int8, 4](0, 1, 3, -1))
print(ptr.load[width=4]())
# 获取值,超出位置的设置为默认值
print(partial_simd_load[4](ptr, 1, 3, -1))
ptr.free()
# [0, 1, 3, 0]
# [-1, 1, 3, -1]
reduce
针对 Buffer
类型的数据,标准库 algorithm
下的 reduce
模块实现了很多计算函数。
判断函数
有三个函数可以快速判断 Buffer
类型的对象中是否有全为 True
或 False
,已经是否包含 True
。例如
from algorithm import all_true, any_true, none_true
from buffer import Buffer
from testing import assert_equal
fn main() raises:
var ptr = DTypePointer[DType.bool].alloc(6)
memset_zero(ptr, 6)
var buf = Buffer[DType.bool, 6](ptr)
assert_equal(none_true(buf), True)
buf[2] = True
assert_equal(any_true(buf), True)
assert_equal(all_true(buf), False)
buf.fill(1)
assert_equal(all_true(buf), True)
ptr.free()
统计函数
from buffer import Buffer
from random import rand, seed
from algorithm import sum, max, min, variance, mean
fn main() raises:
alias size = 16
var ptr = DTypePointer[DType.float64].alloc(size)
seed(1234)
rand(ptr, size)
var buf = Buffer[DType.float64, size](ptr)
print('Sum:', sum(buf))
print('Min:', min(buf), 'Max:', max(buf))
print('Mean:', mean(buf), 'Var:', variance(buf))
ptr.free()
累积函数
最简单的累积函数是累加和累乘
fn main() raises:
alias size = 4
var ptr = DTypePointer[DType.int8].alloc(size)
ptr.store(0, SIMD[DType.int8](1, 3, 3, 4))
var src = Buffer[DType.int8, size](ptr)
print('product', product(src))
var ptr2 = DTypePointer[DType.int8].alloc(size)
memset_zero(ptr2, size)
var dest = Buffer[DType.int8, size](ptr2)
cumsum(dest, src)
print('cumsum:', dest.load[width=size](0))
ptr.free()
ptr2.free()
# product 36
# cumsum: [1, 4, 7, 11]
更复杂点的布尔判断函数。例如
fn main() raises:
alias size = 4
var ptr = DTypePointer[DType.int8].alloc(size)
ptr.store(0, SIMD[DType.int8](1, 3, 3, 4))
var src = Buffer[DType.int8, size](ptr)
@parameter
fn reduce_fn[type: DType, width: Int](val: SIMD[type, width], /) -> Bool:
return val % 2 == 0
@parameter
fn continue_fn(val: Bool) -> Bool:
if val: print('even')
else: print('odd')
return True
var res = reduce_boolean[reduce_fn, continue_fn](src, False)
print(res)
ptr.free()
其中 reduce_fn
用于获取每一个值,并返回一个布尔值,continue_fn
可以获取 reduce_fn
的返回值,并返回一个布尔值。如果 continue_fn
返回 False
则会立马退出,发挥结果将为 False
。
第二个参数为初始值,是一个布尔值。
再来一个更复杂点的函数
from buffer import Buffer
from random import rand, seed
from algorithm import reduce, map_reduce, reduce_boolean
fn reduce_vec_to_scalar_fn[type: DType, width: Int](val: SIMD[type, width], /) -> SIMD[type, 1]:
return val.reduce_add()
fn main() raises:
alias size = 8
var ptr = DTypePointer[DType.int8].alloc(size)
seed(1234)
rand(ptr, size)
var src = Buffer[DType.int8, size](ptr)
var ptr2 = DTypePointer[DType.int8].alloc(size)
memset_zero(ptr2, size)
var dest = Buffer[DType.int8, size](ptr2)
@parameter
fn input_gen_fn[type: DType, width: Int](index: Int) -> SIMD[type, width]:
return src.load[width=width](index).cast[type]()
@parameter
fn reduce_vec_to_vec_fn[dtype: DType, stype: DType, width: Int](dest: SIMD[dtype, width], src: SIMD[stype, width]) -> SIMD[dtype, width]:
print('src:', src)
print('dest:', dest)
return dest + src.cast[dtype]()
var res = map_reduce[4, size, DType.int8, DType.int8, input_gen_fn, reduce_vec_to_vec_fn, reduce_vec_to_scalar_fn](dest, 0)
print(res)
print(dest.load[width=size](0))
print(src.load[width=size](0))
ptr.free()
ptr2.free()
其中,input_gen_fn
用于创建向量,reduce_vec_to_vec_fn
操作两个长度相同的向量,reduce_vec_to_scalar_fn
用于将向量元素聚合为一个标量值。
fn main() raises:
alias size = 8
var ptr = DTypePointer[DType.int8].alloc(size)
ptr.store[width=size](0, SIMD[DType.int8, size](3, 3, 6, 7, 8, 5, 4, 9))
var src = Buffer[DType.int8, size](ptr)
print(src.load[width=size](0))
@parameter
fn reduce_fn[stype: DType, dtype: DType, width: Int](src: SIMD[stype, width], dest: SIMD[dtype, width]) -> SIMD[stype, width]:
return src.min(dest.cast[stype]())
var res = reduce[reduce_fn](src, SIMD[DType.int8, 1](10))
print(res)
ptr.free()
这个函数也可以应用于 NDBuffer
类型,用于对矩阵某一维度进行聚合。例如
fn main() raises:
alias size = 8
var ptr = DTypePointer[DType.int8].alloc(size)
ptr.store[width=size](0, SIMD[DType.int8, size](3, 3, 6, 7, 8, 5, 4, 9))
var src = NDBuffer[DType.int8, 3, (2, 2, 2)](ptr)
print(src)
var ptr2 = DTypePointer[DType.int8].alloc(4)
memset_zero(ptr2, 4)
var dest = NDBuffer[DType.int8, 2](ptr2, DimList(2, 2))
@parameter
fn map_fn[stype: DType, dtype: DType, width: Int](src: SIMD[stype, width], dest: SIMD[dtype, width]) -> SIMD[stype, width]:
return src.min(dest.cast[stype]())
fn reduce_fn[type: DType, width: Int](val: SIMD[type, width]) -> SIMD[type, 1]:
return val.reduce_min()
reduce[map_fn, reduce_fn, 1](src, dest, SIMD[DType.int8, 1](10))
print(dest)
ptr.free()
# NDBuffer([[[3, 3],
# [6, 7]],
# [[8, 5],
# [4, 9]]], dtype=int8, shape=2x2x2)
# NDBuffer([[3, 3],
# [4, 5]], dtype=int8, shape=2x2)
将 2x2x2
的 NDBuffer
根据第二维进行聚合,保留较小的那个值。