Bootstrap

Mojo 学习 ——指针

Mojo 学习 —— 指针


这一节我们学习一下 Mojo 内存管理模块,这个包里面包含 4 个子模块。包括

  • memory:定义了操作内存的函数
  • reference:定义了 Reference 类型即操作
  • unsafe:不安全指针类型`
  • unsafe_pointer:通用不安全指针类型

unsafe

unsafe 模块下面定义了两种类型的指针结构:

  • LegacyPointer:定义包含一个寄存器可传递类型地址的结构。
  • DTypePointer:定义包含给定 dtype 类型地址的结构。

PointerLegacyPointer 的别名

LegacyPointer

主要用于存储和操作寄存器可传递类型(整数、浮点数和布尔值等)。例如

from testing import assert_equal, assert_true, assert_not_equal

fn main() raises:
    var nullptr = Pointer[Int]()
    assert_equal(str(nullptr), "0x0")
    var ptr = Pointer[Int].alloc(1)
    print(ptr)
    assert_true(str(ptr).startswith("0x"))
    assert_not_equal(str(ptr), "0x0")
    ptr.free()
# 0x56267776c010

assert_equal 断言两个对象相等,assert_true 断言条件为真,assert_not_equal 断言两个对象不相等

将指针转换为字符串时,将返回十六进制字符串,包含该指针目标内存位置的十六进制表示。

也可以使用静态方法构建一个空指针

var nullptr = Pointer[Int].get_null()

记住使用完记得调用 free 方法释放内存。

指针可以使用 [] 进行解引用

var ptr = Pointer[Int].alloc(1)
ptr[] = 42
assert_equal(ptr[], 42)
ptr.free()

将指针指向可变字符串

alias payload = "$Modular!Mojo!HelloWorld^"
var ptr = Pointer[String].alloc(1)
__get_address_as_uninit_lvalue(ptr.address) = String()
ptr[] = payload
assert_equal(ptr[], payload)
ptr.free()

其中 __get_address_as_uninit_lvalue = String() 代码返回一个未初始化的内存,并在未初始化的内存中初始化一个新值。

使用指针来指向自定义结构体

@value
struct Pair:
    var first: Int
    var second: Int

fn main() raises:
    var ptr = Pointer[Pair].alloc(1)
    ptr[].first = 42
    ptr[].second = 24
    
    assert_equal(ptr[].first, 42)
    assert_equal(ptr[].second, 24)
    ptr.free()

分别使用 storeload 方法来存储和访问指针指向的值

var size = 3
var ptr = Pointer[Int].alloc(size)
for i in range(size):
    ptr.store(i, i)
for i in range(size):
    print(ptr.load(i))
ptr.free()

此外,Pointer 还支持 nt_store,即非时效存储,指的是跳过处理器缓存,直接将数据写入主内存的操作,是一种用于优化大数据块内存写入操作的技术。

这种情况下要求地址必须正确对齐,即数据在内存中的起始地址必须是特定大小的倍数。比如 avx51264 字节, avx232 字节, avx16 字节。

使用 bitcast 可以转换指针指向值的类型,但转换有可能会损失精度。例如

var new_ptr = ptr.bitcast[Bool]()
for i in range(size):
    print(new_ptr.load(i))
new_ptr.free()
# False
# False
# False

可以使用 offset 让指针偏移,例如

var new_ptr = ptr
for _ in range(3):
    print(new_ptr[])
    new_ptr = new_ptr.offset(1)

也可以使用加法运算,让指针偏移。例如上面的代码与下面两种方式一样

for i in range(3):
    print((ptr + i)[])
    
for i in range(3):
    print(ptr[i])

这点有点类似 C 语言中的指针自增遍历指针指向的元素,可以使用加减以及大小比较等操作。

使用静态方法 address_of 可以直接获取对象的地址。例如

var pair = Pair(1, 2)
var ptr = Pointer.address_of(pair)
print(ptr[].first, ptr[].second)
# 1 2

DTypePointer

DTypePointer 结构体主要用于操作 DType 类型。基本上 Pointer 支持的操作 DTypePointer 都支持,还有一些额外的方法。

例如,定义空指针及一个长度为 1 的指针

var nullptr = DTypePointer[DType.float32]()
assert_equal(str(nullptr), "0x0")

var ptr = DTypePointer[DType.float32].alloc(1)
assert_true(str(ptr).startswith("0x"))
assert_not_equal(str(ptr), "0x0")
ptr.free()

DTypePointer 可以使用下标加索引的方式访问或设置指针对应位置的值。例如

alias size = 4
var a = DTypePointer[DType.int32].alloc(size)
for i in range(size):
    a[i] = -i
for i in range(size):
    print(a[i], end=' ')
a.free()

还可以使用 storeload 快速存储和获取 SIMD 向量。例如

a.store(0, SIMD[DType.int32, 4](9, 5, 2, 7))
print(a.load[width=4]())

类似地,可以使用 simd_nt_store 方法,以非时效存储的方式来存储 SIMD 向量。

DTypePointer 还支持以跨越的方式存储和加载 SIMD 向量。例如

var a = DTypePointer[DType.int32].alloc(size)
a.store(0, SIMD[DType.int32, 4](9, 5, 2, 7))
# strided store: [1, 5, 3, 7]
a.simd_strided_store[2](SIMD[DType.int32, 2](1, 3), 2)
# strided load [1, 7]
assert_equal(a.simd_strided_load[2](3), SIMD[DType.int32, 2](1, 7))

只要设置需要存储或读取的向量长度,以及每个向量元素之间的跨度(步长)即可。

还有两个更高级点的方法,用于修改值或获取值。

scatter

scatter 可以根据传入的偏移向量以及值向量来修改指针对应位置的值。其中偏移向量必须是整数,且长度必须是 2 的幂。例如

from testing import assert_almost_equal

fn main() raises:
    var ptr = DTypePointer[DType.float32].alloc(4)
    ptr.store(0, SIMD[ptr.type, 4](0.0))
    # 定义测试函数
    @parameter
    def _test_scatter[
        width: Int
    ](
        offset: SIMD[_, width],
        val: SIMD[ptr.type, width],
        desired: SIMD[ptr.type, 4],
    ):
        ptr.scatter(offset, val)
        var actual = ptr.load[width=4](0)
        assert_almost_equal(
            actual, desired, msg="_test_scatter", atol=0.0, rtol=0.0
        )
    # 将第三个值设置为 2
    _test_scatter[1](UInt16(2), 2.0, SIMD[ptr.type, 4](0.0, 0.0, 2.0, 0.0))
    # 将第二个值设置为 1,多次出现同一个位置取最后依次出现时的值
    _test_scatter(
        SIMD[DType.uint32, 4](1, 1, 1, 1),
        SIMD[ptr.type, 4](-1.0, 2.0, -2.0, 1.0),
        SIMD[ptr.type, 4](0.0, 1.0, 2.0, 0.0),
    )
    # 倒序设置值
    _test_scatter(
        SIMD[DType.uint64, 4](3, 2, 1, 0),
        SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0),
        SIMD[ptr.type, 4](3.0, 2.0, 1.0, 0.0),
    )
    # 释放内存
    ptr.free()

这种方式可以很方便地修改数据中任意位置的值。还可以使用掩码的方式只修改掩码为 True 时对应位置的值。

from testing import assert_almost_equal

fn main() raises:
    var ptr = DTypePointer[DType.float32].alloc(4)
    ptr.store(0, SIMD[ptr.type, 4](0.0))
    # 测试函数
    @parameter
    def _test_masked_scatter[
        width: Int
    ](
        offset: SIMD[_, width],
        val: SIMD[ptr.type, width],
        mask: SIMD[DType.bool, width],
        desired: SIMD[ptr.type, 4],
    ):
        ptr.scatter(offset, val, mask)
        var actual = ptr.load[width=4](0)
        assert_almost_equal(
            actual, desired, msg="_test_masked_scatter", atol=0.0, rtol=0.0
        )
    # False,不修改值
    _test_masked_scatter[1](
        Int16(2), 2.0, False, SIMD[ptr.type, 4](0.0, 0.0, 0.0, 0.0)
    )
    # True,将位置 3 的值修改为 2
    _test_masked_scatter[1](
        Int32(2), 2.0, True, SIMD[ptr.type, 4](0.0, 0.0, 2.0, 0.0)
    )
    # 执行前三次修改,对同一个位置(1)修改了三次,结果为 -2
    _test_masked_scatter(  # Test with repeated offsets
        SIMD[DType.int64, 4](1, 1, 1, 1),
        SIMD[ptr.type, 4](-1.0, 2.0, -2.0, 1.0),
        SIMD[DType.bool, 4](True, True, True, False),
        SIMD[ptr.type, 4](0.0, -2.0, 2.0, 0.0),
    )
    # 偏移为倒序,不修改第二个值
    _test_masked_scatter(
        SIMD[DType.index, 4](3, 2, 1, 0),
        SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0),
        SIMD[DType.bool, 4](True, False, True, True),
        SIMD[ptr.type, 4](3.0, 2.0, 2.0, 0.0),
    )

    ptr.free()

由于 SIMD 向量长度必须是 2 的幂,掩码方式的加入,可以让我们的修改更自由,可以不用全部修改。

gather

gather 方法可以根据传入的偏移向量来依次获取指针对应位置的值。例如

from testing import assert_equal, assert_almost_equal


fn main() raises:
    var ptr = DTypePointer[DType.float32].alloc(4)
    ptr.store(0, SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0))
    # 定义测试函数
    @parameter
    def _test_gather[
        width: Int
    ](offset: SIMD[_, width], desired: SIMD[ptr.type, width]):
        var actual = ptr.gather(offset)
        assert_almost_equal(
            actual, desired, msg="_test_gather", atol=0.0, rtol=0.0
        )
    # 需要获取值的位置
    var offset = SIMD[DType.int64, 8](3, 0, 2, 1, 2, 0, 3, 1)
    # 提取值后的结果
    var desired = SIMD[ptr.type, 8](3.0, 0.0, 2.0, 1.0, 2.0, 0.0, 3.0, 1.0)
    # 测试
    _test_gather[1](UInt16(2), 2.0)
    _test_gather(offset.cast[DType.uint32]().slice[2](), desired.slice[2]())
    _test_gather(offset.cast[DType.uint64]().slice[4](), desired.slice[4]())
    # 释放内存
    ptr.free()

还可以使用掩码的方式,在掩码为 False 时,不取出 offset 对应位置的值,而是返回默认值,为 True 时,将 offset 对应位置的值取出。例如

from testing import assert_equal, assert_almost_equal

fn main() raises:
    var ptr = DTypePointer[DType.float32].alloc(4)
    ptr.store(0, SIMD[ptr.type, 4](0.0, 1.0, 2.0, 3.0))
    # 测试掩码的函数
    @parameter
    def _test_masked_gather[
        width: Int
    ](
        offset: SIMD[_, width],
        mask: SIMD[DType.bool, width],
        default: SIMD[ptr.type, width],
        desired: SIMD[ptr.type, width],
    ):
        var actual = ptr.gather(offset, mask, default)
        assert_almost_equal(
            actual, desired, msg="_test_masked_gather", atol=0.0, rtol=0.0
        )

    var offset = SIMD[DType.int64, 8](3, 0, 2, 1, 2, 0, 3, 1)
    var desired = SIMD[ptr.type, 8](3.0, 0.0, 2.0, 1.0, 2.0, 0.0, 3.0, 1.0)
    # 将 3 的位置设置为 False
    var mask = (offset >= 0) & (offset < 3)
    # 默认值为 -1
    var default = SIMD[ptr.type, 8](-1.0)
    # 使用掩码之后的值
    desired = SIMD[ptr.type, 8](-1.0, 0.0, 2.0, 1.0, 2.0, 0.0, -1.0, 1.0)
    # 测试
    _test_masked_gather[1](Int16(2), False, -1.0, -1.0)
    _test_masked_gather[1](Int32(2), True, -1.0, 2.0)
    _test_masked_gather(offset, mask, default, desired)

    ptr.free()

bitcast

该模块中还提供了一个 bitcast 函数,用于将一个 SIMD 值转换为另一个 SIMD 值。例如

assert_equal(
    bitcast[DType.int8, 8](SIMD[DType.int16, 4](1, 2, 3, 4)),
    SIMD[DType.int8, 8](1, 0, 2, 0, 3, 0, 4, 0),
)

assert_equal(
    bitcast[DType.int32, 1](SIMD[DType.int8, 4](0xFF, 0x00, 0xFF, 0x55)),
    Int32(1442775295),
)

unsafe_pointer

该模块实现了通用不安全指针类型。定义了 UnsafePointer 结构体来表示不安全类型的指针,以及 5 个操作该结构体指针的函数。

UnsafePointer

定义了一种可以指向任何可移动通用值的指针类型,其定义的方法与 Pointer 基本类似。例如

from testing import assert_equal

fn main() raises:
    var ptr = UnsafePointer[Int].alloc(1)
    ptr[0] = 0
    ptr[] += 1
    assert_equal(ptr[], 1)
    ptr.free()

定义、访问及解引用方式是一样的。

var ptr = UnsafePointer[Int].alloc(5)
for i in range(5):
    ptr[i] = i
for i in range(5):
    assert_equal(ptr[i], i)
ptr.free()

使用 address_of 静态方法

var local = 1
assert_not_equal(0, int(UnsafePointer[Int].address_of(local)))

将指针转换为不同的类型

var local = 1
var ptr = UnsafePointer[Int].address_of(local)
var aliased_ptr = ptr.bitcast[SIMD[DType.uint8, 4]]()

assert_equal(int(ptr), int(ptr.bitcast[Int]()))
assert_equal(int(ptr), int(aliased_ptr))

对指针进行运算

var p1 = UnsafePointer[Int].alloc(1)

assert_true((p1 - 1) < p1)
assert_true((p1 - 1) <= p1)
assert_true(p1 <= p1)
assert_true((p1 + 1) > p1)
assert_true((p1 + 1) >= p1)
assert_true(p1 >= p1)

p1.free()

函数使用

使用 destroy_pointee 函数来销毁指针指向的值,要求指针不为空。例如

var a = 1
var p = UnsafePointer[Int].address_of(a)
destroy_pointee(p)

可以直接将值拷贝或移动到指针所指向的位置上。例如

from testing import assert_not_equal, assert_equal

fn main() raises:

    var a = List(10, 2)

    var p1 = UnsafePointer[List[Int]].alloc(1)

    initialize_pointee_move(p1, a)
    # initialize_pointee_copy(p1, a)
    assert_equal(p1[0][0], 10)
    assert_not_equal(p1[0][1], 10)

反过来,可以将一个值从指针指向的位置中移出,并结束存储在该指针内存位置的值的生存期,随后对该指针的读取将无效。例如

var value = move_from_pointee(ptr)
assert_equal(value[0], 10)
assert_not_equal(value[1], 10)

如果使用 initialize_pointee_move 为指针存储了一个新的值,则可以重新从该指针处读取数据。

还可以在一个指针上的数据转移到另一个指针所指向的地址

from testing import assert_equal, assert_not_equal
from memory.unsafe_pointer import move_pointee

fn main() raises:

    var a = List(10, 2)

    var ptr1 = UnsafePointer[List[Int]].alloc(1)
    initialize_pointee_move(ptr1, a)

    var ptr2 = UnsafePointer[List[Int]].alloc(1)
    move_pointee(src=ptr1, dst=ptr2)
    assert_equal(ptr2[0][1], 2)

reference

这个模块中主要定义了两个结构体:

  • AddressSpace:定义了不同的地址空间,例如 0 表示通用型(generic)。
  • Reference:定义一个不为空的安全引用。

使用 Reference 创建一个引用

var a = 1
var ref = Reference(a)
assert_equal(ref[], a)
var ptr = ref.get_legacy_pointer()
ptr[] = 10
assert_equal(a, 10)

memory

该模块包含 5 个函数:

  • stack_allocation: 根据给定数据类型和元素个数,在栈上分配数据空间并返回一个指针。这种创建指针的方式不需要释放内存。例如
from testing import assert_equal, assert_almost_equal

fn main() raises:
    # DTypePointer
    var ptr1 = stack_allocation[2, DType.float32]()
    ptr1.store(SIMD[DType.float32, 2](3.14, 0.618))
    assert_almost_equal(ptr1[0], 3.14)
    # LegacyPointer
    var ptr2 = stack_allocation[3, Int]()
    for i in range(3):
        ptr2[i] = i
    assert_equal(ptr2[1], 1)
  • memsetmemset_zero:用给定值或 0 填充内存
var ptr1 = stack_allocation[2, DType.float32]()
    memset(ptr1, 1, 2)
    print(ptr1.load[width=2]())

    var ptr2 = stack_allocation[3, Int]()
    memset_zero(ptr2, 3)
    assert_equal(ptr2[1], 0)
# [2.3694278276172396e-38, 2.3694278276172396e-38]

这里用指定值填充有点怪,需要传入一个 SIMD[uint8, 1] 类型的值填充,但是返回的值却和填充值相差很大,不知为何。

  • memcmp: 比较两个指针的字节缓冲区的大小
@value
struct Pair:
    var left: Int
    var right: Int

fn main() raises:
    var pair1 = Pair(1, 2)
    var pair2 = Pair(1, 2)
    # pointer
    var ptr1 = Pointer.address_of(pair1)
    var ptr2 = Pointer.address_of(pair2)
    # 0 表示相等,1 表示大于,-1 表示小于
    var errors2 = memcmp(ptr1, ptr2, 1)
    assert_equal(errors2, 0)
  • memcpy: 复制数据
alias size = 2
var src = DTypePointer[DType.uint8]().alloc(size * 2)
var dst = DTypePointer[DType.uint8]().alloc(size * 2)
for i in range(size * 2):
    dst[i] = 0
    src[i] = 2

memcpy(dst, src, size)

print(dst.load[width=size*2]())
;