今天不想改bug 2023-12-28 16:11 采纳率: 0%
浏览 6
已结题

websocket客户端连接问题

写了一个websocket服务器,但是用websocket客户端在线测试连接后无法互相通信,不知道是什么问题

#include <iostream>
#include <string>
#include <map>
#include <set>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <cstring>
#include <thread>
#include <openssl/sha.h>
#include <regex>
#include <vector>

// WebSocket帧的结构
struct WebSocketFrame {
    bool fin;
    uint8_t opcode;
    bool mask;
    uint64_t payload_length;
    uint8_t masking_key[4];
    std::string payload;
};

// 解析WebSocket帧
WebSocketFrame parseWebSocketFrame(const std::string& data) {
    WebSocketFrame frame;
    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(data.c_str());

    frame.fin = (bytes[0] & 0x80) != 0;
    frame.opcode = bytes[0] & 0x0F;
    frame.mask = (bytes[1] & 0x80) != 0;
    frame.payload_length = bytes[1] & 0x7F;

    size_t offset = 2;
    if (frame.payload_length == 126) {
        frame.payload_length = (bytes[2] << 8) | bytes[3];
        offset += 2;
    }
    else if (frame.payload_length == 127) {
        frame.payload_length = 0;
        for (int i = 0; i < 8; ++i) {
            frame.payload_length = (frame.payload_length << 8) | bytes[2 + i];
        }
        offset += 8;
    }

    if (frame.mask) {
        std::memcpy(frame.masking_key, bytes + offset, 4);
        offset += 4;
    }

    if (frame.payload_length > 0) {
        frame.payload = std::string(reinterpret_cast<const char*>(bytes + offset), frame.payload_length);
    }

    return frame;
}

// 构造WebSocket帧
std::string buildWebSocketFrame(const WebSocketFrame& frame) {
    std::string data;

    uint8_t byte1 = 0x80;  // 设置FIN位为1
    byte1 |= frame.opcode & 0x0F;
    data.push_back(byte1);

    uint8_t byte2 = frame.mask ? 0x80 : 0x00;  // 设置MASK位
    if (frame.payload_length < 126) {
        byte2 |= frame.payload_length;
        data.push_back(byte2);
    }
    else if (frame.payload_length <= 65535) {
        byte2 |= 126;
        data.push_back(byte2);
        data.push_back((frame.payload_length >> 8) & 0xFF);
        data.push_back(frame.payload_length & 0xFF);
    }
    else {
        byte2 |= 127;
        data.push_back(byte2);
        for (int i = 7; i >= 0; --i) {
            data.push_back((frame.payload_length >> (i * 8)) & 0xFF);
        }
    }

    if (frame.mask) {
        data.push_back(frame.masking_key[0]);
        data.push_back(frame.masking_key[1]);
        data.push_back(frame.masking_key[2]);
        data.push_back(frame.masking_key[3]);
    }

    if (!frame.payload.empty()) {
        data += frame.payload;
    }

    return data;
}

// Base64编码
std::string base64_encode(const std::string& data) {
    static const char* table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    std::string output;

    unsigned char* bytes = reinterpret_cast<unsigned char*>(const_cast<char*>(data.c_str()));
    int length = data.length();

    for (int i = 0; i < length; i += 3) {
        uint32_t octet_a = i < length ? bytes[i] : 0;
        uint32_t octet_b = i + 1 < length ? bytes[i + 1] : 0;
        uint32_t octet_c = i + 2 < length ? bytes[i + 2] : 0;

        uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c;

        output.push_back(table[(triple >> 3 * 6) & 0x3F]);
        output.push_back(table[(triple >> 2 * 6) & 0x3F]);
        output.push_back(table[(triple >> 1 * 6) & 0x3F]);
        output.push_back(table[(triple >> 0 * 6) & 0x3F]);
    }

    if (length % 3 == 1) {
        output.resize(output.length() - 2);
        output[output.length() - 1] = '=';
    }
    else if (length % 3 == 2) {
        output.resize(output.length() - 1);
        output[output.length() - 1] = '=';
    }

    return output;
}

// SHA-1哈希
std::string sha1(const std::string& data) {
    unsigned char digest[SHA_DIGEST_LENGTH];
    SHA1(reinterpret_cast<const unsigned char*>(data.c_str()), data.size(), digest);

    std::string output;
    for (int i = 0; i < SHA_DIGEST_LENGTH; ++i) {
        output += std::to_string(static_cast<unsigned int>(digest[i]));
    }
    std::string encoded = base64_encode(output);
    return encoded;
}

// WebSocket服务器类
class WebSocketServer {
public:
    WebSocketServer(int port) : port_(port), running_(false) {}

    void start() {
        // 创建服务器套接字
        server_socket_ = socket(AF_INET, SOCK_STREAM, 0);
        if (server_socket_ == -1) {
            std::cerr << "Failed to create socket." << std::endl;
            return;
        }

        // 绑定服务器
        sockaddr_in server_addr;
        server_addr.sin_family = AF_INET;
        server_addr.sin_port = htons(port_);
        server_addr.sin_addr.s_addr = INADDR_ANY;
        if (bind(server_socket_, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
            std::cerr << "Failed to bind socket." << std::endl;
            close(server_socket_);
            return;
        }

        // 监听连接请求
        if (listen(server_socket_, 10) == -1) {
            std::cerr << "Failed to listen on socket." << std::endl;
            close(server_socket_);
            return;
        }

        std::cout << "WebSocket server started on port " << port_ << std::endl;
        running_ = true;

        while (running_) {
            sockaddr_in client_addr;
            socklen_t client_len = sizeof(client_addr);
            int client_socket = accept(server_socket_, (struct sockaddr *)&client_addr, &client_len);
            if (client_socket == -1) {
                std::cerr << "Failed to accept connection." << std::endl;
                continue;
            }

            std::thread t([this, client_socket]() {
                handleClient(client_socket);
            });
            t.detach();
        }

        close(server_socket_);
    }

    void stop() {
        running_ = false;
    }

private:
    void handleClient(int client_socket) {
        char buffer[4096];
        std::string http_header;
        std::regex r("Sec-WebSocket-Key: (.*)\r\n");
        std::match_results<std::string::const_iterator> m;

        // 接收HTTP握手请求
        int len = recv(client_socket, buffer, sizeof(buffer), 0);
        std::cout << len << std::endl;
        if (len <= 0) {
            std::cerr << "Failed to receive data from client." << std::endl;
            close(client_socket);
            return;
        }
        else
        {
            std::cout << "data from client." << std::endl;
        }

        http_header.append(buffer, len);
        len = 0;
        // 解析Sec-WebSocket-Key
        if (!std::regex_search(http_header, m, r) || m.size() != 2) {
            std::cerr << "Invalid WebSocket handshake." << std::endl;
            close(client_socket);
            return;
        }
        else
        {
            std::cout << "key." << std::endl;
        }

        std::string key = m[1];
        key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            
        std::cout << "socket-key." << std::endl;


        std::string sha1_digest = sha1(key);
        std::cout << "sha1 digest." << sha1_digest << std::endl;


        std::string accept = base64_encode(sha1(key));
        std::cout << "Base64 encoded." << accept << std::endl;

        // 构造HTTP握手响应
        std::string http_response =
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Accept: " + accept + "\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "Content - Type: text / plain; charset = UTF - 8\r\n\r\n"
        std::cout << "http response." << http_response<< std::endl;

        // 发送HTTP握手响应
        int sent = send(client_socket, http_response.data(), http_response.size(), 0);
        if (sent < 0) {
            std::cerr << "error send" << std::endl;
            return;
        }
        else
        {
            std::cout << "send scuess." << sent << std::endl;
        }
        memset(buffer,0,sizeof(buffer));
        // 接收WebSocket帧
        bool connected = true;
        std::string buffer2;
        while (connected) {
            std::cout << "rev loop iteration." << std::endl;
            len = recv(client_socket, buffer, sizeof(buffer), 0);
            std::cout << len << std::endl;
            if (len <= 0) {
            std::cout << "connection closed." << std::endl;
                connected = false;
                close(client_socket);
                //continue;
            }

            buffer2.append(buffer, len);

            // 解析WebSocket帧
            WebSocketFrame frame = parseWebSocketFrame(buffer2);
            if (frame.fin) {
                // 处理完整的WebSocket帧
                processWebSocketFrame(frame);
                buffer2.clear();
            }
        }
        std::cout << "close connection." << std::endl;
    }

    void processWebSocketFrame(const WebSocketFrame& frame) {
        // 处理接收到的WebSocket帧
        // ...

        // 回复数据
        WebSocketFrame responseFrame;
        responseFrame.fin = true;
        responseFrame.opcode = frame.opcode;
        responseFrame.mask = false;
        responseFrame.payload_length = 13;
        responseFrame.payload = "Hello, client!";
        std::string response = buildWebSocketFrame(responseFrame);
        send(client_socket, response.data(), response.size(), 0);
    }

private:
    int server_socket_;
    int port_;
    bool running_;
    int client_socket;
};

int main() {
    int port = 8765;
    WebSocketServer server(port);

    server.start();
    std::cout << " stop the server..." << std::endl;
    std::cin.get();

    server.stop();
    return 0;
}


  • 写回答

2条回答 默认 最新

  • 未来bbaa 2023-12-28 16:18
    关注

    您的代码中出现了几个问题,这些可能导致WebSocket通信失败:

    1. SHA1哈希函数问题sha1函数中转换成字符串的方式并不正确。您应该直接使用哈希后的字节数据进行Base64编码,而不是将每个字节转换为数字然后转换为字符串。转换为数字的字符串不能用来进行Base64编码。

    2. HTTP响应格式问题:您的HTTP响应头部有一些格式错误,例如"Content - Type: text / plain; charset = UTF - 8\r\n\r\n"中存在不必要的空格和一个多余的"\r\n"。这可能导致客户端无法正确地解析响应。

    3. WebSocket帧解析问题:当客户端发送分片消息时,您的代码可能无法正确处理。因为可能会收到不完整的帧,导致parseWebSocketFrame函数不能正确解析。

    4. 线程安全问题:在多线程环境中,client_socket是作为全局变量进行访问的,这可能会导致竞争条件。您应该将其作为参数传递给每个线程。

    下面是一些针对这些问题的修改建议:

    • 修改sha1函数实现,直接对哈希后的数据进行Base64编码:

      std::string sha1(const std::string& data) {
        unsigned char digest[SHA_DIGEST_LENGTH];
        SHA1(reinterpret_cast<const unsigned char*>(data.c_str()), data.size(), digest);
        return base64_encode(std::string(reinterpret_cast<char*>(digest), SHA_DIGEST_LENGTH));
      }
      
    • 修正HTTP响应格式,删除多余的空格和\r\n

      std::string http_response =
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: " + accept + "\r\n"
        "Sec-WebSocket-Version: 13\r\n\r\n";
      
    • 在解析WebSocket帧时,检查是否收到了足够的数据以避免不完整的帧解析。如果数据不完整,等待更多数据到达。

    • 使用互斥锁来确保client_socket的线程安全,或者更好的是,将其作为参数传递给每个处理客户端的线程。

    请注意,这些修改只是基于您提供的代码片段,实际上还有很多可能需要关注的细节,比如处理掩码的逻辑、处理关闭连接的逻辑、处理PING/PONG帧以及正确处理错误等。WebSocket协议有很多细节需要正确实现,以确保与各种WebSocket客户端的兼容性。如果问题仍然存在,您可能还需要使用调试工具检查数据包以及使用适当的WebSocket测试工具来帮助定位问题。

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 2月1日
  • 创建了问题 12月28日

悬赏问题

  • ¥30 关于用python写支付宝扫码付异步通知收不到的问题
  • ¥50 vue组件中无法正确接收并处理axios请求
  • ¥15 隐藏系统界面pdf的打印、下载按钮
  • ¥15 MATLAB联合adams仿真卡死如何解决(代码模型无问题)
  • ¥15 基于pso参数优化的LightGBM分类模型
  • ¥15 安装Paddleocr时报错无法解决
  • ¥15 python中transformers可以正常下载,但是没有办法使用pipeline
  • ¥50 分布式追踪trace异常问题
  • ¥15 人在外地出差,速帮一点点
  • ¥15 如何使用canvas在图片上进行如下的标注,以下代码不起作用,如何修改