WilL846 2023-04-17 21:42 采纳率: 75%
浏览 15
已结题

如何用tensorrt实现两个维度大小不同的张量点乘mul

请问有人知道怎么用tensorrt的api实现两个维度大小不同的张量点乘吗,比如a.shape=[64, 1, 1], b.shape=[64, 240, 320], 维度顺序是(C,H,W), a×b的shape是[64, 240, 320], 如何用tensorrt的api计算a×b,也就是pytorch里的torch.mul

  • 写回答

2条回答 默认 最新

  • Wali_yiwa59418 2023-04-18 00:42
    关注

    在TensorRT中,可以使用plugin来自定义计算算法,实现两个维度大小不同的张量的点乘操作。具体流程如下:

    1. 实现一个自定义的TensorRT插件,可以继承IPluginV2接口。在实现该插件时,需要定义插件输入和输出的数据格式(data format),以及插件需要的配置。

    2. 在插件的实现中,可以直接获取输入和输出tensor的指针,然后利用循环遍历的方式计算点乘操作。

    下面是一个实现的例子:

    #include "NvInfer.h"
    #include <cstdio>
    
    using namespace nvinfer1;
    
    class MultiplyPlugin : public IPluginV2
    {
    public:
        MultiplyPlugin() {}
        MultiplyPlugin(const void* data, size_t length)
        {
            const char *d = reinterpret_cast<const char*>(data), *a = d;
            mInputDims.nbDims = read<int>(d);
            for (int i = 0; i < mInputDims.nbDims; ++i)
                mInputDims.d[i] = read<int>(d);
            mOutputDims.nbDims = read<int>(d);
            for (int i = 0; i < mOutputDims.nbDims; ++i)
                mOutputDims.d[i] = read<int>(d);
            assert(d == a + length);
        }
        ~MultiplyPlugin() {}
    
        int getNbOutputs() const override { return 1; }
        Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override { return mOutputDims; }
        bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR); }
        void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputs, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override
        {
            mDataType = type;
        }
        int initialize() override { return 0; }
        void terminate() override {}
        size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
        int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override
        {
            const float* input = reinterpret_cast<const float*>(inputs[0]);
            const float* weight = reinterpret_cast<const float*>(inputs[1]);
            float* output = reinterpret_cast<float*>(outputs[0]);
            const int inputSize = mInputDims.d[0];
            const int outputSize = mOutputDims.d[0] * mOutputDims.d[1] * mOutputDims.d[2];
            for (int n = 0; n < batchSize; ++n)
            {
                for (int c = 0; c < mOutputDims.d[0]; ++c)
                {
                    for (int h = 0; h < mOutputDims.d[1]; ++h)
                    {
                        for (int w = 0; w < mOutputDims.d[2]; ++w)
                        {
                            const int inputIndex = (n * inputSize) + c;
                            const int weightIndex = (n * outputSize) + (c * mOutputDims.d[1] * mOutputDims.d[2]) + (h * mOutputDims.d[2]) + w;
                            output[weightIndex] = input[inputIndex] * weight[weightIndex];
                        }
                    }
                }
            }
            return 0;
        }
        size_t getSerializationSize() const override
        {
            return sizeof(int)*(1 + mInputDims.nbDims + mOutputDims.nbDims);
        }
        void serialize(void* buffer) const override
        {
            char *d = reinterpret_cast<char*>(buffer), *a = d;
            write(d, mInputDims.nbDims);
            for (int i = 0; i < mInputDims.nbDims; ++i)
                write(d, mInputDims.d[i]);
            write(d, mOutputDims.nbDims);
            for (int i = 0; i < mOutputDims.nbDims; ++i)
                write(d, mOutputDims.d[i]);
            assert(d == a + getSerializationSize());
        }
        void destroy() override { delete this; }
        const char* getPluginType() const override { return "MultiplyPlugin"; }
        const char* getPluginVersion() const override { return "1.0"; }
        void setPluginNamespace(const char* pluginNamespace) override { mNameSpace = pluginNamespace; }
        const char* getPluginNamespace() const override { return mNameSpace.c_str(); }
    
    private:
        template<typename _T>
        static void write(char*& buffer, const _T& val)
        {
            *reinterpret_cast<_T*>(buffer) = val;
            buffer += sizeof(_T);
        }
        template<typename _T>
        static _T read(const char*& buffer)
        {
            _T val = *reinterpret_cast<const _T*>(buffer);
            buffer += sizeof(_T);
            return val;
        }
    
        DataType mDataType = DataType::kFLOAT;
        Dims mInputDims, mOutputDims;
        std::string mNameSpace;
    };
    
    class MultiplyPluginCreator : public IPluginCreator
    {
    public:
        MultiplyPluginCreator()
        {
            mPluginAttributes.emplace_back(PluginField("in_depth", nullptr, PluginFieldType::kINT32, 1));
            mPluginAttributes.emplace_back(PluginField("in_height", nullptr, PluginFieldType::kINT32, 1));
            mPluginAttributes.emplace_back(PluginField("in_width", nullptr, PluginFieldType::kINT32, 1));
            mPluginAttributes.emplace_back(PluginField("out_depth", nullptr, PluginFieldType::kINT32, 1));
            mPluginAttributes.emplace_back(PluginField("out_height", nullptr, PluginFieldType::kINT32, 1));
            mPluginAttributes.emplace_back(PluginField("out_width", nullptr, PluginFieldType::kINT32, 1));
        }
        ~MultiplyPluginCreator() {}
    
        const char* getPluginName() const override { return "MultiplyPlugin"; }
        const char* getPluginVersion() const override { return "1.0"; }
    
        const PluginFieldCollection* getFieldNames() override { return &mPluginAttributes; }
        IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
        {
            const PluginField* fields = fc->fields;
            int inDepth = 1, inHeight = 1, inWidth = 1;
            int outDepth = 1, outHeight = 1, outWidth = 1;
            for (int i = 0; i < fc->nbFields; ++i)
            {
                if (!strcmp(fields[i].name, "in_depth"))
                    inDepth = *(int*)fields[i].data;
                if (!strcmp(fields[i].name, "in_height"))
                    inHeight = *(int*)fields[i].data;
                if (!strcmp(fields[i].name, "in_width"))
                    inWidth = *(int*)fields[i].data;
                if (!strcmp(fields[i].name, "out_depth"))
                    outDepth = *(int*)fields[i].data;
                if (!strcmp(fields[i].name, "out_height"))
                    outHeight = *(int*)fields[i].data;
                if (!strcmp(fields[i].name, "out_width"))
                    outWidth = *(int*)fields[i].data;
            }
            Dims inputDims = Dims3(inDepth, inHeight, inWidth);
            Dims outputDims = Dims3(outDepth, outHeight, outWidth);
            MultiplyPlugin* plugin = new MultiplyPlugin();
            plugin->setPluginNamespace(mNamespace.c_str());
            plugin->initialize();
            plugin->configureWithFormat(&inputDims, 1, &outputDims, 1, DataType::kFLOAT, PluginFormat::kLINEAR, 1);
            return plugin;
        }
        IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
        {
            MultiplyPlugin* plugin = new MultiplyPlugin(serialData, serialLength);
            plugin->setPluginNamespace(mNamespace.c_str());
            return plugin;
        }
        void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }
        const char* getPluginNamespace() const override { return mNamespace.c_str(); }
    
    private:
        std::string mNamespace;
        static PluginFieldCollection mPluginAttributes;
    };
    
    PluginFieldCollection MultiplyPluginCreator::mPluginAttributes;
    
    extern "C" IPluginCreator& getPluginCreator()
    {
        static MultiplyPluginCreator pluginCreator;
        return pluginCreator;
    }
    

    在上述代码中,自定义了一个名为MultiplyPlugin的插件,其中实现了自定义的点乘计算操作。该插件包含两个输入参数和一个输出参数,分别是输入张量、权重张量和输出张量。

    接下来,可以在TensorRT中使用该自定义插件来实现两个维度大小不同的张量点乘。

    // 创建Engine
    IBuilder* builder = createInferBuilder(gLogger);
    INetworkDefinition* network = builder->createNetworkV2(0U);
    ITensor* a = network->addInput("a", DataType::kFLOAT, Dims3(1, 1, 64));
    ITensor* b = network->addInput("b", DataType::kFLOAT, Dims3(320, 240, 64));
    ITensor* ab[] = {a, b};
    auto plugin = network->addPluginV2(ab, 2, createPlugin("MultiplyPlugin", pluginFactory));
    ITensor* output = plugin->getOutput(0);
    network->markOutput(*output);
    

    在创建Engine时,需要调用createPlugin函数来实例化自定义插件,并将两个输入张量作为参数添加到插件中。创建Engine后,就可以像其他TensorRT网络一样使用了。

    以上就是利用TensorRT的API实现两个大小不同张量点乘的步骤。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 4月27日
  • 已采纳回答 4月19日
  • 创建了问题 4月17日

悬赏问题

  • ¥15 python+selenium,在新增时弹出了一个输入框
  • ¥15 苹果验机结果的api接口哪里有??单次调用1毛钱及以下。
  • ¥20 学生成绩管理系统设计
  • ¥15 来一个cc穿盾脚本开发者
  • ¥15 CST2023安装报错
  • ¥15 使用diffusionbert生成文字 结果是PAD和UNK怎么办
  • ¥15 有人懂怎么做大模型的客服系统吗?卡住了卡住了
  • ¥20 firefly-rk3399上启动卡住了
  • ¥15 如何删除这个虚拟音频
  • ¥50 hyper默认的default switch