Mojo 学习 —— 指针
文章目录
这一节我们学习一下
Mojo
内存管理模块,这个包里面包含
4
个子模块。包括
memory
:定义了操作内存的函数reference
:定义了Reference
类型即操作unsafe
:不安全指针类型`unsafe_pointer
:通用不安全指针类型
unsafe
unsafe
模块下面定义了两种类型的指针结构:
LegacyPointer
:定义包含一个寄存器可传递类型地址的结构。DTypePointer
:定义包含给定dtype
类型地址的结构。
Pointer
是LegacyPointer
的别名
LegacyPointer
主要用于存储和操作寄存器可传递类型(整数、浮点数和布尔值等)。例如
from testing import assert_equal, assert_true, assert_not_equal
fn main() raises:
var nullptr = Pointer[Int]()
assert_equal(str(nullptr), "0x0")
var ptr = Pointer[Int].alloc(1)
print(ptr)
assert_true(str(ptr).startswith("0x"))
assert_not_equal(str(ptr), "0x0")
ptr.free()
# 0x56267776c010
assert_equal
断言两个对象相等,assert_true
断言条件为真,assert_not_equal
断言两个对象不相等
将指针转换为字符串时,将返回十六进制字符串,包含该指针目标内存位置的十六进制表示。
也可以使用静态方法构建一个空指针
var nullptr = Pointer[Int].get_null()
记住使用完记得调用 free
方法释放内存。
指针可以使用 []
进行解引用
var ptr = Pointer[Int].alloc(1)
ptr[] = 42
assert_equal(ptr[], 42)
ptr.free()
将指针指向可变字符串
alias payload = "$Modular!Mojo!HelloWorld^"
var ptr = Pointer[String].alloc(1)
__get_address_as_uninit_lvalue(ptr.address) = String()
ptr[] = payload
assert_equal(ptr[], payload)
ptr.free()
其中 __get_address_as_uninit_lvalue = String()
代码返回一个未初始化的内存,并在未初始化的内存中初始化一个新值。
使用指针来指向自定义结构体
@value
struct Pair:
var first: Int
var second: Int
fn main() raises:
var ptr = Pointer[Pair].alloc(1)
ptr[].first = 42
ptr[].second = 24
assert_equal(ptr[].first, 42)
assert_equal(ptr[].second, 24)
ptr.free()
分别使用 store
和 load
方法来存储和访问指针指向的值
var size = 3
var ptr = Pointer[Int].alloc(size)
for i in range(size):
ptr.store(i, i)
for i in range(size):
print(ptr.load(i))
ptr.free()
此外,
Pointer
还支持nt_store
,即非时效存储,指的是跳过处理器缓存,直接将数据写入主内存的操作,是一种用于优化大数据块内存写入操作的技术。这种情况下要求地址必须正确对齐,即数据在内存中的起始地址必须是特定大小的倍数。比如
avx512
为64
字节,avx2
为32
字节,avx
为16
字节。
使用 bitcast
可以转换指针指向值的类型,但转换有可能会损失精度。例如
var new_ptr = ptr.bitcast[Bool]()
for i in range(size):
print(new_ptr.load(i))
new_ptr.free()
# False
# False
# False
可以使用 offset
让指针偏移,例如
var new_ptr = ptr
for _ in range(3):
print(new_ptr[])
new_ptr = new_ptr.offset(1)
也可以使用加法运算,让指针偏移。例如上面的代码与下面两种方式一样
for i in range(3):
print((ptr + i)[])
for i in range(3):
print(ptr[i])
这点有点类似 C
语言中的指针自增遍历指针指向的元素,可以使用加减以及大小比较等操作。
使用静态方法 address_of
可以直接获取对象的地址。例如
var pair = Pair(1, 2)
var ptr = Pointer.address_of(pair)
print(ptr[].first, ptr[].second)
# 1 2
DTypePointer
DTypePointer
结构体主要用于操作 DType
类型。基本上 Pointer
支持的操作 DTypePointer
都支持,还有一些额外的方法。
例如,定义空指针及一个长度为 1
的指针
var nullptr = DTypePointer[DType.float32]()
assert_equal(str(nullptr), "0x0")
var ptr = DTypePointer[DType.float32].alloc(1)
assert_true(str(ptr).startswith("0x"))
assert_not_equal(str(ptr), "0x0")
ptr.free()
DTypePointer
可以使用下标加索引的方式访问或设置指针对应位置的值。例如
alias size = 4
var a = DTypePointer[DType.int32].alloc(size)
for i in range(size):
a[i] = -i
for i in range(size):
print(a[i], end=' ')
a.free()
还可以使用 store
和 load
快速存储和获取 SIMD
向量。例如
a.store(0, SIMD[DType.int32, 4](9, 5, 2, 7))
print(a.load[width=4]())
类似地,可以使用
simd_nt_store
方法,以非时效存储的方式来存储 SIMD 向量。
DTypePointer
还支持以跨越的方式存储和加载 SIMD
向量。例如
var a = DTypePointer[DType.int32].alloc(size)
a.store(0, SIMD[DType.int32, 4](9, 5, 2, 7))
# strided store: [1, 5, 3, 7]
a.simd_strided_store[2](SIMD[DType.int32, 2](1, 3), 2)
# strided load [1, 7]
assert_equal(a.simd_strided_load[2](3), SIMD[DType.int32, 2](1, 7))
只要设置需要存储或读取的向量长度,以及每个向量元素之间的跨度(步长)即可。
还有两个更高级点的方法,用于修改值或获取值。
scatter
scatter
可以根据传入的偏移向量以及值向量来修改指针对应位置的值。其中偏移向量必须是整数,且长度必须是 2
的幂。例如
from testing import assert_almost_equal
fn main() raises:
var ptr = DTypePointer[DType.float32].alloc(4)
ptr.store(0, SIMD[ptr.type, 4](0.0))
# 定义测试函数
@parameter
def _test_scatter[
width: Int
](
offset: SIMD[_, width],
val: SIMD[ptr.type, width],
desired: SIMD[ptr.type, 4],
):
ptr.scatter(offset, val)
var actual = ptr.load[width=4](0)
assert_almost_equal(
actual, desired, msg="_test_scatter", atol=0.0, rtol=0.0
)
# 将第三个值设置为 2
_test_scatter[1](UInt16(2), 2.0, SIMD[ptr.type, 4](0.0, 0.0, 2.0, 0.0))
# 将第二个值设置为 1,多次出现同一个位置取最后依次出现时的值
_test_scatter(
SIMD[DType.uint32, 4](1, 1, 1, 1),
SIMD[ptr.type, 4](-1.0, 2.0, -2.0, 1.0),
SIMD[ptr.type, 4](0.0, 1.0, 2.0, 0.0),
)
# 倒序设置值
_test_scatter(
SIMD[DType.uint64, 4](3, 2, 1, 0),
SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0),
SIMD[ptr.type, 4](3.0, 2.0, 1.0, 0.0),
)
# 释放内存
ptr.free()
这种方式可以很方便地修改数据中任意位置的值。还可以使用掩码的方式只修改掩码为 True
时对应位置的值。
from testing import assert_almost_equal
fn main() raises:
var ptr = DTypePointer[DType.float32].alloc(4)
ptr.store(0, SIMD[ptr.type, 4](0.0))
# 测试函数
@parameter
def _test_masked_scatter[
width: Int
](
offset: SIMD[_, width],
val: SIMD[ptr.type, width],
mask: SIMD[DType.bool, width],
desired: SIMD[ptr.type, 4],
):
ptr.scatter(offset, val, mask)
var actual = ptr.load[width=4](0)
assert_almost_equal(
actual, desired, msg="_test_masked_scatter", atol=0.0, rtol=0.0
)
# False,不修改值
_test_masked_scatter[1](
Int16(2), 2.0, False, SIMD[ptr.type, 4](0.0, 0.0, 0.0, 0.0)
)
# True,将位置 3 的值修改为 2
_test_masked_scatter[1](
Int32(2), 2.0, True, SIMD[ptr.type, 4](0.0, 0.0, 2.0, 0.0)
)
# 执行前三次修改,对同一个位置(1)修改了三次,结果为 -2
_test_masked_scatter( # Test with repeated offsets
SIMD[DType.int64, 4](1, 1, 1, 1),
SIMD[ptr.type, 4](-1.0, 2.0, -2.0, 1.0),
SIMD[DType.bool, 4](True, True, True, False),
SIMD[ptr.type, 4](0.0, -2.0, 2.0, 0.0),
)
# 偏移为倒序,不修改第二个值
_test_masked_scatter(
SIMD[DType.index, 4](3, 2, 1, 0),
SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0),
SIMD[DType.bool, 4](True, False, True, True),
SIMD[ptr.type, 4](3.0, 2.0, 2.0, 0.0),
)
ptr.free()
由于 SIMD
向量长度必须是 2
的幂,掩码方式的加入,可以让我们的修改更自由,可以不用全部修改。
gather
gather
方法可以根据传入的偏移向量来依次获取指针对应位置的值。例如
from testing import assert_equal, assert_almost_equal
fn main() raises:
var ptr = DTypePointer[DType.float32].alloc(4)
ptr.store(0, SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0))
# 定义测试函数
@parameter
def _test_gather[
width: Int
](offset: SIMD[_, width], desired: SIMD[ptr.type, width]):
var actual = ptr.gather(offset)
assert_almost_equal(
actual, desired, msg="_test_gather", atol=0.0, rtol=0.0
)
# 需要获取值的位置
var offset = SIMD[DType.int64, 8](3, 0, 2, 1, 2, 0, 3, 1)
# 提取值后的结果
var desired = SIMD[ptr.type, 8](3.0, 0.0, 2.0, 1.0, 2.0, 0.0, 3.0, 1.0)
# 测试
_test_gather[1](UInt16(2), 2.0)
_test_gather(offset.cast[DType.uint32]().slice[2](), desired.slice[2]())
_test_gather(offset.cast[DType.uint64]().slice[4](), desired.slice[4]())
# 释放内存
ptr.free()
还可以使用掩码的方式,在掩码为 False
时,不取出 offset
对应位置的值,而是返回默认值,为 True
时,将 offset
对应位置的值取出。例如
from testing import assert_equal, assert_almost_equal
fn main() raises:
var ptr = DTypePointer[DType.float32].alloc(4)
ptr.store(0, SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0))
# 测试掩码的函数
@parameter
def _test_masked_gather[
width: Int
](
offset: SIMD[_, width],
mask: SIMD[DType.bool, width],
default: SIMD[ptr.type, width],
desired: SIMD[ptr.type, width],
):
var actual = ptr.gather(offset, mask, default)
assert_almost_equal(
actual, desired, msg="_test_masked_gather", atol=0.0, rtol=0.0
)
var offset = SIMD[DType.int64, 8](3, 0, 2, 1, 2, 0, 3, 1)
var desired = SIMD[ptr.type, 8](3.0, 0.0, 2.0, 1.0, 2.0, 0.0, 3.0, 1.0)
# 将 3 的位置设置为 False
var mask = (offset >= 0) & (offset < 3)
# 默认值为 -1
var default = SIMD[ptr.type, 8](-1.0)
# 使用掩码之后的值
desired = SIMD[ptr.type, 8](-1.0, 0.0, 2.0, 1.0, 2.0, 0.0, -1.0, 1.0)
# 测试
_test_masked_gather[1](Int16(2), False, -1.0, -1.0)
_test_masked_gather[1](Int32(2), True, -1.0, 2.0)
_test_masked_gather(offset, mask, default, desired)
ptr.free()
bitcast
该模块中还提供了一个 bitcast
函数,用于将一个 SIMD
值转换为另一个 SIMD
值。例如
assert_equal(
bitcast[DType.int8, 8](SIMD[DType.int16, 4](1, 2, 3, 4)),
SIMD[DType.int8, 8](1, 0, 2, 0, 3, 0, 4, 0),
)
assert_equal(
bitcast[DType.int32, 1](SIMD[DType.int8, 4](0xFF, 0x00, 0xFF, 0x55)),
Int32(1442775295),
)
unsafe_pointer
该模块实现了通用不安全指针类型。定义了 UnsafePointer
结构体来表示不安全类型的指针,以及 5
个操作该结构体指针的函数。
UnsafePointer
定义了一种可以指向任何可移动通用值的指针类型,其定义的方法与 Pointer
基本类似。例如
from testing import assert_equal
fn main() raises:
var ptr = UnsafePointer[Int].alloc(1)
ptr[0] = 0
ptr[] += 1
assert_equal(ptr[], 1)
ptr.free()
定义、访问及解引用方式是一样的。
var ptr = UnsafePointer[Int].alloc(5)
for i in range(5):
ptr[i] = i
for i in range(5):
assert_equal(ptr[i], i)
ptr.free()
使用 address_of
静态方法
var local = 1
assert_not_equal(0, int(UnsafePointer[Int].address_of(local)))
将指针转换为不同的类型
var local = 1
var ptr = UnsafePointer[Int].address_of(local)
var aliased_ptr = ptr.bitcast[SIMD[DType.uint8, 4]]()
assert_equal(int(ptr), int(ptr.bitcast[Int]()))
assert_equal(int(ptr), int(aliased_ptr))
对指针进行运算
var p1 = UnsafePointer[Int].alloc(1)
assert_true((p1 - 1) < p1)
assert_true((p1 - 1) <= p1)
assert_true(p1 <= p1)
assert_true((p1 + 1) > p1)
assert_true((p1 + 1) >= p1)
assert_true(p1 >= p1)
p1.free()
函数使用
使用 destroy_pointee
函数来销毁指针指向的值,要求指针不为空。例如
var a = 1
var p = UnsafePointer[Int].address_of(a)
destroy_pointee(p)
可以直接将值拷贝或移动到指针所指向的位置上。例如
from testing import assert_not_equal, assert_equal
fn main() raises:
var a = List(10, 2)
var p1 = UnsafePointer[List[Int]].alloc(1)
initialize_pointee_move(p1, a)
# initialize_pointee_copy(p1, a)
assert_equal(p1[0][0], 10)
assert_not_equal(p1[0][1], 10)
反过来,可以将一个值从指针指向的位置中移出,并结束存储在该指针内存位置的值的生存期,随后对该指针的读取将无效。例如
var value = move_from_pointee(ptr)
assert_equal(value[0], 10)
assert_not_equal(value[1], 10)
如果使用 initialize_pointee_move
为指针存储了一个新的值,则可以重新从该指针处读取数据。
还可以在一个指针上的数据转移到另一个指针所指向的地址
from testing import assert_equal, assert_not_equal
from memory.unsafe_pointer import move_pointee
fn main() raises:
var a = List(10, 2)
var ptr1 = UnsafePointer[List[Int]].alloc(1)
initialize_pointee_move(ptr1, a)
var ptr2 = UnsafePointer[List[Int]].alloc(1)
move_pointee(src=ptr1, dst=ptr2)
assert_equal(ptr2[0][1], 2)
reference
这个模块中主要定义了两个结构体:
AddressSpace
:定义了不同的地址空间,例如0
表示通用型(generic
)。- Reference:定义一个不为空的安全引用。
使用 Reference
创建一个引用
var a = 1
var ref = Reference(a)
assert_equal(ref[], a)
var ptr = ref.get_legacy_pointer()
ptr[] = 10
assert_equal(a, 10)
memory
该模块包含 5
个函数:
stack_allocation
: 根据给定数据类型和元素个数,在栈上分配数据空间并返回一个指针。这种创建指针的方式不需要释放内存。例如
from testing import assert_equal, assert_almost_equal
fn main() raises:
# DTypePointer
var ptr1 = stack_allocation[2, DType.float32]()
ptr1.store(SIMD[DType.float32, 2](3.14, 0.618))
assert_almost_equal(ptr1[0], 3.14)
# LegacyPointer
var ptr2 = stack_allocation[3, Int]()
for i in range(3):
ptr2[i] = i
assert_equal(ptr2[1], 1)
memset
和memset_zero
:用给定值或0
填充内存
var ptr1 = stack_allocation[2, DType.float32]()
memset(ptr1, 1, 2)
print(ptr1.load[width=2]())
var ptr2 = stack_allocation[3, Int]()
memset_zero(ptr2, 3)
assert_equal(ptr2[1], 0)
# [2.3694278276172396e-38, 2.3694278276172396e-38]
这里用指定值填充有点怪,需要传入一个 SIMD[uint8, 1]
类型的值填充,但是返回的值却和填充值相差很大,不知为何。
memcmp
: 比较两个指针的字节缓冲区的大小
@value
struct Pair:
var left: Int
var right: Int
fn main() raises:
var pair1 = Pair(1, 2)
var pair2 = Pair(1, 2)
# pointer
var ptr1 = Pointer.address_of(pair1)
var ptr2 = Pointer.address_of(pair2)
# 0 表示相等,1 表示大于,-1 表示小于
var errors2 = memcmp(ptr1, ptr2, 1)
assert_equal(errors2, 0)
memcpy
: 复制数据
alias size = 2
var src = DTypePointer[DType.uint8]().alloc(size * 2)
var dst = DTypePointer[DType.uint8]().alloc(size * 2)
for i in range(size * 2):
dst[i] = 0
src[i] = 2
memcpy(dst, src, size)
print(dst.load[width=size*2]())