需要debug,代码在添加ssl加密传输功能之前可以实现,添加后不能运行出来,不知道是不是逻辑有问题,求解答
客户端代码如下
import socket
import threading
import sys
from gmssl import sm2, func
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
import hashlib
# 服务器 IP 地址和端口号
SERVER_HOST = '127.0.0.1'
SERVER_PORT = 8000
# 创建 socket 对象
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 连接服务器
client_socket.connect((SERVER_HOST, SERVER_PORT))
# 用于加密的密钥
#key = b'secret_key'
# 用于服务器身份认证的 token
token = 'server_token'
# 用于用户身份认证的 token
user_token = ''
def create_key(data):#生成对称加密密钥,长度16的字符串
m= hashlib.md5()
m.update(data.encode("utf-8"))
n=m.hexdigest()
return n[0:16]
public_key=client_socket.recv(1024*1024).decode()#连接后接收服务器的公钥
private_key = ''
sm2_crypt = sm2.CryptSM2(public_key=public_key, private_key=private_key)
# 对明文进行补足
def pad_text(text):
while len(text) % 16 != 0:
text += ' '.encode('utf-8')
return text
# 去除补足的空格和换行
def unpad_text(text):
return text.rstrip()
def receive_server_messages():
"""
接收服务器发送的消息
"""
while True:
try:
public_key=client_socket.recv(1024*1024).decode()#连接后接收服务器的公钥
private_key = ''
sm2_crypt = sm2.CryptSM2(public_key=public_key, private_key=private_key)
message = client_socket.recv(1024*1024).decode()
print(message)
except Exception as e:
print(f'Error: {e}')
sys.exit()
def send_messages():
"""
发送消息给服务器
"""
while True:
try:
username = input('Enter your username: ')
#client_socket.send(username.encode())
inputData = input();
if(inputData=="quit"):
print("退出成功")
break
if(inputData==''):
print("请输入信息")
continue
###加密
key=create_key(inputData).encode()#生成对称密钥
iv = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
crypt_sm4 = CryptSM4()
crypt_sm4.set_key(key, SM4_ENCRYPT)
encrypt_username = crypt_sm4.crypt_ecb(username.encode()) # 加密用户名
encrypt_data = crypt_sm4.crypt_ecb(inputData.encode()) # 加密信息
enc_key = sm2_crypt.encrypt(key)#对密钥进行加密
client_socket.send(encrypt_username)
client_socket.send(encrypt_data)
client_socket.send(enc_key)
#if "quit" ==message:
# break
except Exception as e:
print(f'Error: {e}')
client_socket.close()
sys.exit()
def start_client():
"""
启动客户端
"""
# 请求用户输入用户名
#username = input('Enter your username: ')
#client_socket.send(username.encode())
# 启动接收消息和发送消息的线程
receive_thread = threading.Thread(target=receive_server_messages)
send_thread = threading.Thread(target=send_messages)
receive_thread.start()
send_thread.start()
if __name__ == '__main__':
start_client()
服务器代码
import socket
import threading
from gmssl import sm2, func
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
import hashlib
# 服务器 IP 地址和端口号
SERVER_HOST = '127.0.0.1'
SERVER_PORT = 8000
private_key = '00B9AB0B828FF68872F21A837FC303668428DEA11DCD1B24429D0C99E24EED83D5'
public_key = 'B9C9A6E04E9C91F7BA880429273747D7EF5DDEB0BB2FF6317EB00BEF331A83081A6994B8993F3F5D6EADDDB81872266C87C018FB4162F5AF347B483E24620207'
sm2_crypt = sm2.CryptSM2(public_key=public_key, private_key=private_key)
# 存储客户端信息的字典,格式为 {client_socket: (username, [group1, group2])}
clients = {}
# 存储群组信息的字典,格式为 {group_name: [client_socket1, client_socket2, ...]}
groups = {}
# 用于加密的密钥
#key = b'secret_key'
# 用于服务器身份认证的 token
token = 'server_token'
# 用于用户身份认证的 token
user_token = {}
# 锁
lock = threading.Lock()
# 对明文进行补足
def pad_text(text):
while len(text) % 16 != 0:
text += ' '.encode('utf-8')
return text
#去除补足的空格和换行
def unpad_text(text):
return text.rstrip()
def send_to_all_clients(message, sender_socket):
"""
将消息和公钥发送给所有客户端,除了发送者
"""
for client_socket in clients:
if client_socket != sender_socket:
client_socket.send(message)
client_socket.send(public_key.encode())#发送公钥
def handle_client_connection(client_socket):
"""
处理客户端连接
"""
# 请求客户端输入用户名
username = client_socket.recv(1024).decode().strip()
# 将客户端信息存储到 clients 字典中
print(username)
clients[client_socket] = (username, [])
while True:
# 接收客户端发送的消息
print("等待接收客户端消息")
message = client_socket.recv(1024*1024) #服务器接收数据.decode().strip()接收客户端套接字,最多接收 1024 字节,并将其解码为字符串类型
enc_key = client_socket.recv(1024*1024)
#message.decryte()
###数据解密
dec_key =sm2_crypt.decrypt(enc_key)#解出密钥
print("得到密钥:"+dec_key.decode())
iv = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # bytes类型
crypt_sm4 = CryptSM4()
crypt_sm4.set_key(dec_key, SM4_DECRYPT)
dec_message = crypt_sm4.crypt_ecb(message) #解出数据
print("解密数据成功")
print(dec_message)
###
# # 处理用户注册
# if message.startswith('REGISTER'):
# _, username = message.split(' ')
# clients[client_socket] = (username, [])
# client_socket.send('You have successfully registered.'.encode())
# 处理私聊消息
if message.startswith('PRIVATE'):
split_message = message.split(' ', 2)
if len(split_message) >= 3:
_, recipient_name, message_content = split_message
for client_socket, (usernames, _) in clients.items():
if usernames == recipient_name:
client_socket.send(f'{username} (private): {message_content}'.encode())
break
else:
client_socket.send(f'User {recipient_name} is not online.'.encode())
else:
client_socket.send('Invalid private message format. Usage: PRIVATE [recipient] [message]'.encode())
# 处理创建群组
elif message.startswith('CREATE'):
_, group_name = message.split(' ')
if group_name not in groups:
groups[group_name] = [client_socket]
clients[client_socket][1].append(group_name)
client_socket.send(f'Group {group_name} has been created.'.encode())
else:
client_socket.send(f'Group {group_name} already exists.'.encode())
# 处理加入群组
elif message.startswith('JOIN'):
_, group_name = message.split(' ')
if group_name in groups:
groups[group_name].append(client_socket)
clients[client_socket][1].append(group_name)
client_socket.send(f'You have joined group {group_name}.'.encode())
else:
client_socket.send(f'Group {group_name} does not exist.'.encode())
# 处理发送群组消息
elif message.startswith('GROUP'):
_, group_name, message_content = message.split(' ', 2)
if group_name in groups:
for group_member_socket in groups[group_name]:
if group_member_socket != client_socket:
group_member_socket.send(f'{username} ({group_name}): {message_content}'.encode())
else:
client_socket.send(f'Group {group_name} does not exist.'.encode())
# 处理离开群组
elif message.startswith('LEAVE'):
_, group_name = message.split(' ')
if group_name in groups:
if client_socket in groups[group_name]:
groups[group_name].remove(client_socket)
clients[client_socket][1].remove(group_name)
client_socket.send(f'You have left group {group_name}.'.encode())
else:
client_socket.send(f'You are not a member of group {group_name}.'.encode())
else:
client_socket.send(f'Group {group_name} does not exist.'.encode())
# 处理退出聊天室
elif message == 'quit':
print(f'{username} has left the chatroom.', client_socket)
client_socket.send('You are disconnected from the server.'.encode())
client_socket.close()
del clients[client_socket]
send_to_all_clients(f'{username} has left the chatroom.', client_socket)
break
# 处理群组列表
elif message == 'GROUPS':
group_list = ', '.join(groups.keys())
if group_list:
client_socket.send(f'Available groups: {group_list}'.encode())
else:
client_socket.send('No groups available.'.encode())
# 处理无法识别的命令
else:
client_socket.send('Invalid command.'.encode())
continue
def start_server():
"""
启动服务器
"""
# 创建 socket 对象
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 绑定 IP 地址和端口号
server_socket.bind((SERVER_HOST, SERVER_PORT)) #bind()指定本地地址
# 开始监听连接
server_socket.listen()
print(f'Server is running on {SERVER_HOST}:{SERVER_PORT}')
while True:
# 接受客户端连接请求
client_socket, client_address = server_socket.accept()
print(f'New connection from {client_address}')
# 处理客户端连接
client_thread = threading.Thread(target=handle_client_connection, args=(client_socket,))
client_thread.start()
if __name__ == '__main__':
start_server()