请问有人知道怎么用 TensorRT 的 API 实现两个形状不同的张量逐元素相乘(点乘,带广播)吗?比如 a.shape=[64, 1, 1],b.shape=[64, 240, 320],维度顺序是 (C, H, W),a×b 的 shape 是 [64, 240, 320]。如何用 TensorRT 的 API 计算 a×b,也就是 PyTorch 里的 torch.mul?
2条回答 默认 最新
- Wali_yiwa59418 2023-04-18 00:42关注
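在TensorRT中,可以使用plugin来自定义计算算法,实现两个维度大小不同的张量的点乘操作。(补充说明:对于本例这种可广播的逐元素相乘,也可以直接使用内置的IElementWiseLayer,即 network->addElementWise(*a, *b, ElementWiseOperation::kPROD);只要两个输入的维度数相同,且每一维的长度相等或其中一个为1,TensorRT就会沿长度为1的维度自动广播。下面介绍的是自定义插件的做法。)具体流程如下: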
实现一个自定义的TensorRT插件,可以继承IPluginV2接口。在实现该插件时,需要定义插件输入和输出的数据格式(data format),以及插件需要的配置。
在插件的实现中,可以在enqueue函数里获取输入和输出tensor的指针,然后计算点乘操作。需要注意:这些指针指向的是GPU显存,不能在CPU上直接解引用;应当用CUDA kernel完成计算,或者先把数据拷贝到主机内存、用循环遍历计算后再拷回显存。
下面是一个实现的例子:
#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

using namespace nvinfer1;

/// @brief TensorRT plugin computing a per-channel broadcast multiply.
///
/// Input 0 ("a") holds one scale per channel, shape (C, 1, 1); input 1 ("b")
/// is the full tensor, shape (C, H, W). The output has the shape of input 1 and
/// out[c][h][w] = a[c] * b[c][h][w], i.e. torch.mul with broadcasting.
///
/// NOTE(review): the method signatures follow the pre-8.0 IPluginV2 interface
/// that the original code was written against (no noexcept, `void** outputs`).
/// Confirm against the TensorRT version you build with.
class MultiplyPlugin : public IPluginV2
{
public:
    /// Creates a plugin whose shapes are filled in later by configureWithFormat().
    MultiplyPlugin() = default;

    /// Creates a plugin with known input-0 and output shapes.
    MultiplyPlugin(const Dims& inputDims, const Dims& outputDims)
        : mInputDims(inputDims)
        , mOutputDims(outputDims)
    {
    }

    /// Rebuilds a plugin from the byte layout produced by serialize().
    MultiplyPlugin(const void* data, size_t length)
    {
        const char* cursor = static_cast<const char*>(data);
        const char* const begin = cursor;
        mInputDims = readDims(cursor);
        mOutputDims = readDims(cursor);
        assert(cursor == begin + length);
        (void) begin;  // silence unused warnings when asserts are compiled out
        (void) length;
    }

    ~MultiplyPlugin() override = default;

    int getNbOutputs() const override
    {
        return 1;
    }

    /// The output has the shape of the full-size second input. When TensorRT
    /// does not supply two inputs, fall back to the shape given at construction.
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(index == 0);
        (void) index;
        if (inputs != nullptr && nbInputDims == 2)
        {
            return inputs[1];
        }
        return mOutputDims;
    }

    bool supportsFormat(DataType type, PluginFormat format) const override
    {
        return type == DataType::kFLOAT && format == PluginFormat::kLINEAR;
    }

    /// Records the shapes enqueue() indexes with. The original implementation
    /// dropped them, leaving mInputDims / mOutputDims uninitialized.
    void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputs, int nbOutputs, DataType type,
        PluginFormat format, int maxBatchSize) override
    {
        (void) format;
        (void) maxBatchSize;
        if (inputs != nullptr && nbInputs >= 1)
        {
            mInputDims = inputs[0];
        }
        if (outputs != nullptr && nbOutputs >= 1)
        {
            mOutputDims = outputs[0];
        }
        mDataType = type;
    }

    int initialize() override
    {
        return 0;
    }

    void terminate() override {}

    size_t getWorkspaceSize(int maxBatchSize) const override
    {
        (void) maxBatchSize;
        return 0;
    }

    /// @brief Computes output = a (broadcast over H, W) * b for every sample.
    ///
    /// TensorRT hands enqueue() device pointers, so they must not be
    /// dereferenced on the host. This implementation stages the data through
    /// host memory so the file stays plain C++; it is correct but synchronizes
    /// the stream. For production use a CUDA kernel (or the built-in
    /// IElementWiseLayer with kPROD, which broadcasts natively).
    ///
    /// @return 0 on success, non-zero on invalid shapes or a CUDA error.
    int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace,
        cudaStream_t stream) override
    {
        (void) workspace;
        if (batchSize < 0 || mOutputDims.nbDims != 3 || mInputDims.nbDims < 1)
        {
            return 1;
        }
        const int channels = mOutputDims.d[0];
        const int height = mOutputDims.d[1];
        const int width = mOutputDims.d[2];
        // Input 0 must provide exactly one scale per output channel.
        if (channels < 0 || height < 0 || width < 0 || mInputDims.d[0] != channels)
        {
            return 1;
        }

        const size_t planeSize = static_cast<size_t>(height) * static_cast<size_t>(width);
        const size_t scaleCount = static_cast<size_t>(batchSize) * static_cast<size_t>(channels);
        const size_t elemCount = scaleCount * planeSize;
        if (elemCount == 0)
        {
            return 0;
        }

        std::vector<float> scales(scaleCount);
        std::vector<float> values(elemCount);
        if (cudaMemcpyAsync(scales.data(), inputs[0], scaleCount * sizeof(float), cudaMemcpyDeviceToHost, stream)
                != cudaSuccess
            || cudaMemcpyAsync(values.data(), inputs[1], elemCount * sizeof(float), cudaMemcpyDeviceToHost, stream)
                != cudaSuccess
            || cudaStreamSynchronize(stream) != cudaSuccess)
        {
            return 1;
        }

        // Plane p covers sample p / channels, channel p % channels; both inputs
        // are laid out linearly, so one scale applies to one contiguous plane.
        for (size_t plane = 0; plane < scaleCount; ++plane)
        {
            const float scale = scales[plane];
            float* const row = values.data() + plane * planeSize;
            for (size_t i = 0; i < planeSize; ++i)
            {
                row[i] *= scale;
            }
        }

        // Synchronize before returning: `values` is destroyed at scope exit.
        if (cudaMemcpyAsync(outputs[0], values.data(), elemCount * sizeof(float), cudaMemcpyHostToDevice, stream)
                != cudaSuccess
            || cudaStreamSynchronize(stream) != cudaSuccess)
        {
            return 1;
        }
        return 0;
    }

    /// Two nbDims counters plus every stored extent. The original counted only
    /// one counter, so serialize() wrote 4 bytes past the buffer.
    size_t getSerializationSize() const override
    {
        return sizeof(int) * static_cast<size_t>(2 + mInputDims.nbDims + mOutputDims.nbDims);
    }

    /// Layout: [in.nbDims][in.d...][out.nbDims][out.d...], all as int.
    void serialize(void* buffer) const override
    {
        char* cursor = static_cast<char*>(buffer);
        const char* const begin = cursor;
        writeDims(cursor, mInputDims);
        writeDims(cursor, mOutputDims);
        assert(cursor == begin + getSerializationSize());
        (void) begin;
    }

    /// TensorRT owns plugin lifetime through destroy(), hence the raw delete.
    void destroy() override
    {
        delete this;
    }

    /// Required by IPluginV2; without it the class is abstract.
    IPluginV2* clone() const override
    {
        MultiplyPlugin* copy = new MultiplyPlugin(mInputDims, mOutputDims);
        copy->mDataType = mDataType;
        copy->mNameSpace = mNameSpace;
        return copy;
    }

    const char* getPluginType() const override
    {
        return "MultiplyPlugin";
    }

    const char* getPluginVersion() const override
    {
        return "1.0";
    }

    void setPluginNamespace(const char* pluginNamespace) override
    {
        // Constructing std::string from a null pointer is undefined behavior.
        mNameSpace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }

    const char* getPluginNamespace() const override
    {
        return mNameSpace.c_str();
    }

private:
    /// Appends `val` to the buffer; memcpy avoids unaligned stores.
    template <typename T>
    static void write(char*& buffer, const T& val)
    {
        std::memcpy(buffer, &val, sizeof(T));
        buffer += sizeof(T);
    }

    /// Reads a `T` from the buffer; memcpy avoids unaligned loads.
    template <typename T>
    static T read(const char*& buffer)
    {
        T val;
        std::memcpy(&val, buffer, sizeof(T));
        buffer += sizeof(T);
        return val;
    }

    static void writeDims(char*& buffer, const Dims& dims)
    {
        write(buffer, static_cast<int>(dims.nbDims));
        for (int i = 0; i < dims.nbDims; ++i)
        {
            write(buffer, static_cast<int>(dims.d[i]));
        }
    }

    static Dims readDims(const char*& buffer)
    {
        Dims dims{};
        dims.nbDims = read<int>(buffer);
        // A corrupt count would overrun the fixed-size Dims::d array.
        assert(dims.nbDims >= 0 && dims.nbDims <= Dims::MAX_DIMS);
        for (int i = 0; i < dims.nbDims; ++i)
        {
            dims.d[i] = read<int>(buffer);
        }
        return dims;
    }

    DataType mDataType = DataType::kFLOAT;
    Dims mInputDims{};
    Dims mOutputDims{};
    std::string mNameSpace;
};

/// @brief Factory that builds MultiplyPlugin from plugin fields or serialized data.
///
/// Recognized INT32 fields: in_depth, in_height, in_width (shape of input 0)
/// and out_depth, out_height, out_width (shape of the output). Missing fields
/// default to 1.
class MultiplyPluginCreator : public IPluginCreator
{
public:
    MultiplyPluginCreator()
    {
        static const char* const kFieldNames[]
            = {"in_depth", "in_height", "in_width", "out_depth", "out_height", "out_width"};
        for (const char* fieldName : kFieldNames)
        {
            mPluginAttributes.emplace_back(PluginField(fieldName, nullptr, PluginFieldType::kINT32, 1));
        }
        // PluginFieldCollection is a plain {count, pointer} view over the vector.
        mFC.nbFields = static_cast<int>(mPluginAttributes.size());
        mFC.fields = mPluginAttributes.data();
    }

    ~MultiplyPluginCreator() override = default;

    const char* getPluginName() const override
    {
        return "MultiplyPlugin";
    }

    const char* getPluginVersion() const override
    {
        return "1.0";
    }

    const PluginFieldCollection* getFieldNames() override
    {
        return &mFC;
    }

    /// Builds a plugin from `fc`; the parsed shapes are passed to the plugin
    /// (the original parsed them and then discarded them).
    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
    {
        (void) name;
        int inDepth = 1, inHeight = 1, inWidth = 1;
        int outDepth = 1, outHeight = 1, outWidth = 1;

        struct Binding
        {
            const char* key;
            int* slot;
        };
        const Binding bindings[] = {{"in_depth", &inDepth}, {"in_height", &inHeight}, {"in_width", &inWidth},
            {"out_depth", &outDepth}, {"out_height", &outHeight}, {"out_width", &outWidth}};

        if (fc != nullptr && fc->fields != nullptr)
        {
            for (int i = 0; i < fc->nbFields; ++i)
            {
                const PluginField& field = fc->fields[i];
                if (field.name == nullptr || field.data == nullptr)
                {
                    continue;
                }
                for (const Binding& binding : bindings)
                {
                    if (std::strcmp(field.name, binding.key) == 0)
                    {
                        *binding.slot = *static_cast<const int*>(field.data);
                    }
                }
            }
        }

        const Dims inputDims = Dims3(inDepth, inHeight, inWidth);
        const Dims outputDims = Dims3(outDepth, outHeight, outWidth);
        MultiplyPlugin* plugin = new MultiplyPlugin(inputDims, outputDims);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }

    IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
    {
        (void) name;
        MultiplyPlugin* plugin = new MultiplyPlugin(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }

    void setPluginNamespace(const char* libNamespace) override
    {
        mNamespace = (libNamespace != nullptr) ? libNamespace : "";
    }

    const char* getPluginNamespace() const override
    {
        return mNamespace.c_str();
    }

private:
    std::string mNamespace;
    std::vector<PluginField> mPluginAttributes;
    PluginFieldCollection mFC{};
};

/// Returns the process-wide creator instance for MultiplyPlugin.
extern "C" IPluginCreator& getPluginCreator()
{
    static MultiplyPluginCreator pluginCreator;
    return pluginCreator;
}
在上述代码中,自定义了一个名为MultiplyPlugin的插件,其中实现了自定义的点乘计算操作。该插件包含两个输入参数和一个输出参数,分别是输入张量、权重张量和输出张量。
接下来,可以在TensorRT中使用该自定义插件来实现两个维度大小不同的张量点乘。
// 创建Engine IBuilder* builder = createInferBuilder(gLogger); INetworkDefinition* network = builder->createNetworkV2(0U); ITensor* a = network->addInput("a", DataType::kFLOAT, Dims3(1, 1, 64)); ITensor* b = network->addInput("b", DataType::kFLOAT, Dims3(320, 240, 64)); ITensor* ab[] = {a, b}; auto plugin = network->addPluginV2(ab, 2, createPlugin("MultiplyPlugin", pluginFactory)); ITensor* output = plugin->getOutput(0); network->markOutput(*output);
在构建网络时,需要调用createPlugin函数来实例化自定义插件,并通过addPluginV2把两个输入张量连同该插件一起添加到网络中。由该网络构建出Engine后,就可以像其他TensorRT网络一样使用了。
以上就是利用TensorRT的API实现两个大小不同张量点乘的步骤。
本回答被题主选为最佳回答 , 对您是否有帮助呢?解决 无用评论 打赏 举报
悬赏问题
- ¥15 请问各位,如何在Jetson nano主控板的Ubuntu系统中安装PyQt5
- ¥15 MAC安装佳能LBP2900驱动的网盘提取码
- ¥400 微信停车小程序谁懂的来
- ¥15 ATAC测序到底用什么peak文件做Diffbind差异分析
- ¥15 安装ubantu过程中第一个vfat 文件挂载失败
- ¥20 GZ::CTF如何兼容一些靶机?
- ¥15 etcd集群部署问题
- ¥20 谁可以帮我一下问一下各位
- ¥15 为何重叠加权后love图的SMD与svyCreateTableOne函数绘制基线表的不一致
- ¥150 求 《小魔指》街机游戏机整合模拟软件