参与:朱乾树、黄小天 PyTorch 中的基本单位是张量(Tensor)。本文的主旨是如何在 PyTorch 中实现 Tensor 的概述,以便用户可从 Python shell 与之交互。本文主要回答以下四个主要问题: 1. PyTorch 如何通过扩展 Python 解释器来定义可以从 Python 代码中调用的 Tensor 类型? 2. PyTorch 如何封装实际定义 Tensor 属性和方法的 C 的类库? 3. PyTorch 的 C 类包装器如何生成 Tensor 方法的代码? 4. PyTorch 的编译系统如何编译这些组件并生成可运行的应用程序? 扩展 Python 解释器 PyTorch 定义了一个新的包 torch。本文中,我们将考虑._C 模块。这是一个用 C 编写的被称为「扩展模块」的 Python 模块,它允许我们定义新的内置对象类型(例如 Tensor)和调用 C / C ++函数。 ._C 模块定义在 torch/csrc/Module.cpp 文件中。init_C()/ PyInit__C()函数创建模块并根据需要添加方法定义。这个模块被传递给一些不同的__init()函数,这些函数会添加更多的对象到模块中,以及注册新的类型等。 __init() 可调用的部分函数如下: ASSERT_TRUE(THPDoubleTensor_init(module));ASSERT_TRUE(THPFloatTensor_init(module));ASSERT_TRUE(THPHalfTensor_init(module));ASSERT_TRUE(THPLongTensor_init(module));ASSERT_TRUE(THPIntTensor_init(module));ASSERT_TRUE(THPShortTensor_init(module));ASSERT_TRUE(THPCharTensor_init(module));ASSERT_TRUE(THPByteTensor_init(module)); 这些__init()函数将每种类型的 Tensor 对象添加到._C 模块,以便它们可以在._C 模块中调用。下面我们来了解这些方法的工作原理。 THPTensor 类型 PyTorch 很像底层的 TH 和 THC 类库,它定义了一个专门针对多种不同的类型数据的「通用」Tensor。在考虑这种专业化的工作原理之前,我们首先考虑如何在 Python 中定义新的类型,以及如何创建通用的 THPTensor 类型。 Python 运行时将所有 Python 对象都视为 PyObject * 类型的变量,PyObject * 是所有 Python 对象的「基本类型」。每个 Python 类型包含对象的引用计数,以及指向对象的「类型对象」的指针。类型对象确定类型的属性。例如,该对象可能包含一系列与类型相关联的方法,以及调用哪些 C 函数来实现这些方法。该对象还可能包含表示其状态所需的任意字段。 定义新类型的准则如下: 1. 创建一个结构体,它定义了新对象将包括的属性 2. 定义类型的类型对象 结构体本身可能十分简单。在 Python 中,实际上所有浮点数类型都是堆上的对象。Python float 结构体定义为: typedef struct { PyObject_HEAD double ob_fval;} PyFloatObject; PyObject_HEAD 是引入实现对象的引用计数的代码的宏,以及指向相应类型对象的指针。所以在这种情况下,要实现浮点数,所需的唯一其他「状态」是浮点值本身。 现在,我们来看看 THPTensor 类型的结构题: struct THPTensor { PyObject_HEAD THTensor *cdata;}; 很简单吧?我们只是通过存储一个指针来包装底层 TH 张量。关键部分是为新类型定义「类型对象」。我们的 Python 浮点数的类型对象的示例定义的形式如下: static PyTypeObject py_FloatType = { PyVarObject_HEAD_INIT(NULL, 0) "py.FloatObject", /* tp_name */ sizeof(PyFloatObject), /* tp_basicsize */ 0, /* tp_itemsize */ 0, /* tp_dealloc */ 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_as_async */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ "A floating point number", /* tp_doc */}; 想象一个类型对象的最简单的方法就是定义一组该对象属性的字段。例如,tp_basicsize 字段设置为 sizeof(PyFloatObject)。这是为了让 Python 知道 PyFloatObject 调用 PyObject_New()时需要分配多少内存。你可以设置的字段的完整列表在 CPython 后端的 object.h 中定义:https://github.com/python/cpython/blob/master/Include/object.h. THPTensor 的类型对象是 THPTensorType,它定义在 csrc/generic/Tensor.cpp 文件中。该对象定义了 THPTensor 的类型名称、大小及映射方法等。 我们来看看我们在 PyTypeObject 中设置的 tp_new 函数: PyTypeObject THPTensorType = { PyVarObject_HEAD_INIT(NULL, 0) ... THPTensor_(pynew), /* tp_new */}; (责任编辑:本港台直播) |