写了一个 WebSocket 服务器，但用在线 WebSocket 客户端测试连接后，双方无法互相通信，不知道是什么问题。
#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <string>
#include <system_error>
#include <thread>
#include <vector>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <openssl/sha.h>
// One WebSocket frame as laid out in RFC 6455 §5.2.
// Every member has a default value, so a default-constructed frame is fully
// initialized: not final, opcode 0, unmasked, zero masking key, empty payload.
struct WebSocketFrame {
    bool fin = false;                        // FIN bit: final fragment of a message
    uint8_t opcode = 0;                      // 4-bit frame opcode
    bool mask = false;                       // MASK bit: payload is XOR-masked with masking_key
    uint64_t payload_length = 0;             // payload length from/for the frame header
    uint8_t masking_key[4] = {0, 0, 0, 0};   // only meaningful when mask is true
    std::string payload;                     // payload bytes
};
// Decodes the first WebSocket frame (RFC 6455 §5.2) contained in `data`.
//
// When the MASK bit is set, the returned payload is already unmasked
// (RFC 6455 §5.3). `mask` and `masking_key` still report what was on the wire.
//
// If `data` does not yet hold a complete frame (header, extended length,
// masking key and payload), nothing beyond data.size() is read. A
// value-initialized frame is returned instead: fin == false, opcode 0,
// payload_length 0, empty payload. Callers that act only on fin == true
// therefore keep accumulating bytes.
//
// consumed: optional. When non-null it receives the number of bytes the frame
// occupies in `data` (header + payload), or 0 if the frame is incomplete. It
// lets a caller strip exactly one frame from a buffer that may hold several.
WebSocketFrame parseWebSocketFrame(const std::string& data, size_t* consumed = nullptr) {
    if (consumed != nullptr) {
        *consumed = 0;
    }
    const size_t size = data.size();
    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(data.data());

    // The two fixed header bytes must be present before anything is decoded.
    if (size < 2) {
        return WebSocketFrame{};
    }

    WebSocketFrame frame{};
    frame.fin = (bytes[0] & 0x80) != 0;
    frame.opcode = bytes[0] & 0x0F;
    frame.mask = (bytes[1] & 0x80) != 0;

    uint64_t length = bytes[1] & 0x7F;
    size_t offset = 2;
    if (length == 126) {
        // 16-bit extended payload length, network byte order.
        if (size < offset + 2) {
            return WebSocketFrame{};
        }
        length = (static_cast<uint64_t>(bytes[2]) << 8) | bytes[3];
        offset += 2;
    } else if (length == 127) {
        // 64-bit extended payload length, network byte order.
        if (size < offset + 8) {
            return WebSocketFrame{};
        }
        length = 0;
        for (size_t i = 0; i < 8; ++i) {
            length = (length << 8) | bytes[2 + i];
        }
        offset += 8;
    }

    if (frame.mask) {
        if (size < offset + 4) {
            return WebSocketFrame{};
        }
        std::memcpy(frame.masking_key, bytes + offset, 4);
        offset += 4;
    }

    // Compare against the bytes that remain instead of computing offset + length,
    // which could overflow for a hostile 64-bit length.
    if (length > size - offset) {
        return WebSocketFrame{};
    }

    frame.payload_length = length;
    frame.payload.assign(data, offset, static_cast<size_t>(length));
    if (frame.mask) {
        // Client-to-server frames are always masked; XOR restores the original bytes.
        for (size_t i = 0; i < frame.payload.size(); ++i) {
            frame.payload[i] = static_cast<char>(
                static_cast<uint8_t>(frame.payload[i]) ^ frame.masking_key[i % 4]);
        }
    }

    if (consumed != nullptr) {
        *consumed = offset + static_cast<size_t>(length);
    }
    return frame;
}
// Encodes `frame` into RFC 6455 §5.2 wire format.
//
// - The FIN bit is always set: only unfragmented frames are produced, and
//   frame.fin is not consulted.
// - The length written into the header is frame.payload.size(). A stale or
//   mistyped frame.payload_length can therefore no longer desynchronize the
//   byte stream seen by the peer.
// - When frame.mask is true, masking_key is emitted and the payload is
//   XOR-masked with it, because the receiver will unmask it (RFC 6455 §5.3).
//   Frames sent by a server must leave mask false.
std::string buildWebSocketFrame(const WebSocketFrame& frame) {
    const uint64_t length = frame.payload.size();

    std::string data;
    data.reserve(14 + frame.payload.size());  // 14 = largest possible header

    data.push_back(static_cast<char>(0x80 | (frame.opcode & 0x0F)));

    const uint8_t mask_bit = frame.mask ? 0x80 : 0x00;
    if (length < 126) {
        data.push_back(static_cast<char>(mask_bit | length));
    } else if (length <= 0xFFFF) {
        // 16-bit extended length, network byte order.
        data.push_back(static_cast<char>(mask_bit | 126));
        data.push_back(static_cast<char>((length >> 8) & 0xFF));
        data.push_back(static_cast<char>(length & 0xFF));
    } else {
        // 64-bit extended length, network byte order.
        data.push_back(static_cast<char>(mask_bit | 127));
        for (int shift = 56; shift >= 0; shift -= 8) {
            data.push_back(static_cast<char>((length >> shift) & 0xFF));
        }
    }

    if (frame.mask) {
        data.append(reinterpret_cast<const char*>(frame.masking_key), 4);
        for (size_t i = 0; i < frame.payload.size(); ++i) {
            data.push_back(static_cast<char>(
                static_cast<uint8_t>(frame.payload[i]) ^ frame.masking_key[i % 4]));
        }
    } else {
        data += frame.payload;
    }
    return data;
}
// Standard Base64 encoding (RFC 4648 §4 alphabet, '=' padding) of arbitrary bytes.
// Each 3 input bytes become 4 output characters. A short final group is
// zero-filled, and the output characters carrying no input bits become '='.
std::string base64_encode(const std::string& data) {
    static const char table[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const size_t length = data.size();

    std::string output;
    output.reserve(((length + 2) / 3) * 4);
    for (size_t i = 0; i < length; i += 3) {
        // Read through unsigned char so bytes >= 0x80 are not sign-extended.
        const uint32_t octet_a = static_cast<unsigned char>(data[i]);
        const uint32_t octet_b = i + 1 < length ? static_cast<unsigned char>(data[i + 1]) : 0u;
        const uint32_t octet_c = i + 2 < length ? static_cast<unsigned char>(data[i + 2]) : 0u;
        const uint32_t triple = (octet_a << 16) | (octet_b << 8) | octet_c;
        output.push_back(table[(triple >> 18) & 0x3F]);
        output.push_back(table[(triple >> 12) & 0x3F]);
        output.push_back(table[(triple >> 6) & 0x3F]);
        output.push_back(table[triple & 0x3F]);
    }

    // Overwrite (do not drop) the characters that encode only zero padding;
    // the output length must stay a multiple of 4.
    const size_t remainder = length % 3;
    if (remainder == 1) {
        output[output.size() - 2] = '=';
        output[output.size() - 1] = '=';
    } else if (remainder == 2) {
        output[output.size() - 1] = '=';
    }
    return output;
}
// Returns the raw 20-byte (SHA_DIGEST_LENGTH) SHA-1 digest of `data` as a
// binary string. It is not hex or Base64 encoded.
// For the WebSocket handshake, Sec-WebSocket-Accept is the Base64 encoding of
// exactly these bytes (RFC 6455 §4.2.2).
// NOTE: OpenSSL 3.0 deprecates the one-shot SHA1() in favour of EVP_Digest*;
// it still works but may emit deprecation warnings.
std::string sha1(const std::string& data) {
    unsigned char digest[SHA_DIGEST_LENGTH];
    SHA1(reinterpret_cast<const unsigned char*>(data.data()), data.size(), digest);
    return std::string(reinterpret_cast<const char*>(digest), SHA_DIGEST_LENGTH);
}
// WebSocket服务器类
class WebSocketServer {
public:
WebSocketServer(int port) : port_(port), running_(false) {}
void start() {
// 创建服务器套接字
server_socket_ = socket(AF_INET, SOCK_STREAM, 0);
if (server_socket_ == -1) {
std::cerr << "Failed to create socket." << std::endl;
return;
}
// 绑定服务器
sockaddr_in server_addr;
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port_);
server_addr.sin_addr.s_addr = INADDR_ANY;
if (bind(server_socket_, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
std::cerr << "Failed to bind socket." << std::endl;
close(server_socket_);
return;
}
// 监听连接请求
if (listen(server_socket_, 10) == -1) {
std::cerr << "Failed to listen on socket." << std::endl;
close(server_socket_);
return;
}
std::cout << "WebSocket server started on port " << port_ << std::endl;
running_ = true;
while (running_) {
sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
int client_socket = accept(server_socket_, (struct sockaddr *)&client_addr, &client_len);
if (client_socket == -1) {
std::cerr << "Failed to accept connection." << std::endl;
continue;
}
std::thread t([this, client_socket]() {
handleClient(client_socket);
});
t.detach();
}
close(server_socket_);
}
void stop() {
running_ = false;
}
private:
void handleClient(int client_socket) {
char buffer[4096];
std::string http_header;
std::regex r("Sec-WebSocket-Key: (.*)\r\n");
std::match_results<std::string::const_iterator> m;
// 接收HTTP握手请求
int len = recv(client_socket, buffer, sizeof(buffer), 0);
std::cout << len << std::endl;
if (len <= 0) {
std::cerr << "Failed to receive data from client." << std::endl;
close(client_socket);
return;
}
else
{
std::cout << "data from client." << std::endl;
}
http_header.append(buffer, len);
len = 0;
// 解析Sec-WebSocket-Key
if (!std::regex_search(http_header, m, r) || m.size() != 2) {
std::cerr << "Invalid WebSocket handshake." << std::endl;
close(client_socket);
return;
}
else
{
std::cout << "key." << std::endl;
}
std::string key = m[1];
key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
std::cout << "socket-key." << std::endl;
std::string sha1_digest = sha1(key);
std::cout << "sha1 digest." << sha1_digest << std::endl;
std::string accept = base64_encode(sha1(key));
std::cout << "Base64 encoded." << accept << std::endl;
// 构造HTTP握手响应
std::string http_response =
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: " + accept + "\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Content - Type: text / plain; charset = UTF - 8\r\n\r\n"
std::cout << "http response." << http_response<< std::endl;
// 发送HTTP握手响应
int sent = send(client_socket, http_response.data(), http_response.size(), 0);
if (sent < 0) {
std::cerr << "error send" << std::endl;
return;
}
else
{
std::cout << "send scuess." << sent << std::endl;
}
memset(buffer,0,sizeof(buffer));
// 接收WebSocket帧
bool connected = true;
std::string buffer2;
while (connected) {
std::cout << "rev loop iteration." << std::endl;
len = recv(client_socket, buffer, sizeof(buffer), 0);
std::cout << len << std::endl;
if (len <= 0) {
std::cout << "connection closed." << std::endl;
connected = false;
close(client_socket);
//continue;
}
buffer2.append(buffer, len);
// 解析WebSocket帧
WebSocketFrame frame = parseWebSocketFrame(buffer2);
if (frame.fin) {
// 处理完整的WebSocket帧
processWebSocketFrame(frame);
buffer2.clear();
}
}
std::cout << "close connection." << std::endl;
}
void processWebSocketFrame(const WebSocketFrame& frame) {
// 处理接收到的WebSocket帧
// ...
// 回复数据
WebSocketFrame responseFrame;
responseFrame.fin = true;
responseFrame.opcode = frame.opcode;
responseFrame.mask = false;
responseFrame.payload_length = 13;
responseFrame.payload = "Hello, client!";
std::string response = buildWebSocketFrame(responseFrame);
send(client_socket, response.data(), response.size(), 0);
}
private:
int server_socket_;
int port_;
bool running_;
int client_socket;
};
// Entry point: runs the demo server on port 8765.
// start() blocks in its accept loop, so the statements after it run only once
// start() has returned (for example when socket setup failed).
int main() {
    const int kListenPort = 8765;
    WebSocketServer server{kListenPort};
    server.start();

    std::cout << " stop the server..." << std::endl;
    std::cin.get();  // wait for Enter before shutting down
    server.stop();
}