基于maven的构建环境:(pom.xml添加netty依赖)
<!--集成netty-->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.32.Final</version>
</dependency>
服务构建主类:
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
 * @Author: geyingke
 * @Date: 2020/7/20
 * @Class: NettyServer
 * @Description: Bootstraps the Netty TCP server on a fixed port and blocks until the
 * server channel is closed. Every accepted connection is initialised by
 * {@link NettyServerInitializer}, which decides per connection whether it speaks
 * websocket or the raw socket protocol.
 **/
public class NettyServer {
    private static final Logger logger = LogManager.getLogger(NettyServer.class);

    private final int port;

    /**
     * @param port local port the server binds to
     */
    public NettyServer(int port) {
        this.port = port;
    }

    /**
     * Binds the server and blocks the calling thread until the server channel closes.
     * Both event-loop groups are shut down exactly once, whether the server stops
     * normally or because of an error. Startup/runtime failures (e.g. the port is
     * already in use) are logged instead of being silently discarded.
     *
     * @throws InterruptedException if interrupted while waiting for the event-loop groups to terminate
     */
    public void start() throws InterruptedException {
        // bossGroup accepts connections; workerGroup handles I/O of accepted channels.
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        boolean interrupted = false;
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .localAddress(port)
                    // The initializer sniffs the protocol of each connection and installs the matching handlers.
                    .childHandler(new NettyServerInitializer());
            ChannelFuture channelFuture = serverBootstrap.bind().sync();
            logger.info(String.format("Netty server started!!!! port: %d", port));
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Remember the interrupt so it can be re-asserted after the groups are shut down.
            interrupted = true;
            logger.warn(String.format("Netty server on port %d was interrupted, shutting down", port), e);
        } catch (Exception e) {
            // Previously swallowed without a trace; a failed bind would go unnoticed.
            logger.error(String.format("Netty server on port %d stopped with an error", port), e);
        } finally {
            workerGroup.shutdownGracefully().sync();
            bossGroup.shutdownGracefully().sync();
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}
server初始化类
- 在初始化时,如果要兼容处理socket请求,socket的处理Handler和相应的编码器必须在初始化的时候完成。目前仍在研究如何在一个handler中处理两种类型的协议。
- 如果socket和websocket的Handler处理类不分开,websocket的握手连接不能正常完成,目前正在寻找原因。
- 当前实现tcp的粘包解决方案不使用netty提供的三种解决方案,连接方为c++程序,无包头标记码,因此循环截取bytebuff中的byte数组信息
import com.galaxyeye.icservice.im.parser.SocketUtils;
import com.galaxyeye.icservice.im.socket.NettySocketHandler;
import com.galaxyeye.icservice.im.webSocket.WebSocketHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Component;
import java.lang.invoke.MethodHandles;
import java.util.List;
/**
 * @Author: geyingke
 * @Date: 2020/7/21
 * @Class: NettyServerInitializer
 * @Description: Per-connection pipeline initialiser. Installs an idle-timeout handler,
 * a protocol-sniffing/framing decoder ({@link SocketParser}) and the raw socket handler.
 * When the first bytes look like an HTTP GET the pipeline is switched to websocket.
 **/
@Component
public class NettyServerInitializer extends ChannelInitializer<SocketChannel> {
    private static final Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());

    @Override
    protected void initChannel(SocketChannel socketChannel) throws Exception {
        // Fire READER_IDLE after 120 s without inbound data; the handlers disconnect on that event.
        socketChannel.pipeline().addLast(new IdleStateHandler(60 * 2, 0, 0));
        /*
         * Notes:
         * 1. To serve socket and websocket on the same port, the socket handler must be installed
         *    here at init time, otherwise later socket messages are not processed.
         * 2. SocketParser tells websocket and socket apart and splits sticky socket packets.
         * 3. When both protocols are supported, socket messages are decoded/encoded inside the
         *    message handling, with codecs added later in the pipeline.
         */
        socketChannel.pipeline().addLast("SocketParser", new SocketParser());
        socketChannel.pipeline().addLast(new NettySocketHandler());
    }

    /**
     * Protocol sniffer and frame splitter.
     *
     * <p>If the buffered bytes start with {@code "GET /"} the connection is treated as a websocket
     * handshake: HTTP/websocket handlers are appended and the socket handler and this decoder are
     * removed (the unread bytes are forwarded to the new handlers by {@link ByteToMessageDecoder}).
     *
     * <p>Otherwise the stream is split into frames whose first 4 bytes are a little-endian length
     * that excludes the length field itself (so a frame is {@code length + 4} bytes). All positions
     * are relative to {@code in.readerIndex()}, so nothing depends on state kept between calls.
     */
    private static class SocketParser extends ByteToMessageDecoder {
        /** Prefix of a websocket (HTTP GET) handshake request. */
        private static final String WEBSOCKET_PREFIX = "GET /";
        /** Size of a complete frame header (length + checkSum + cmd + target + retCode). */
        private static final int BASE_LENGTH = 14;
        /** Size of the leading length field. */
        private static final int LENGTH_FIELD_BYTES = 4;
        /** Anti-flood limit: larger buffers / frames are discarded as unreasonable. */
        private static final int MAX_READABLE_BYTES = 2048;

        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            if (startsWithWebSocketPrefix(in)) {
                switchToWebSocket(ctx);
                return;
            }
            int readableBytes = in.readableBytes();
            if (readableBytes < BASE_LENGTH) {
                // Not even one full header yet; wait for more data.
                return;
            }
            if (readableBytes > MAX_READABLE_BYTES) {
                // Guard against byte-stream flooding: such a large burst is not a legitimate request.
                in.skipBytes(readableBytes);
                return;
            }
            while (in.readableBytes() >= BASE_LENGTH) {
                int length = in.getIntLE(in.readerIndex());
                logger.info("readableBytes: " + in.readableBytes() + "\t custom decode msg length: " + length);
                // length must at least cover the rest of the header, and the frame must stay under the limit
                // (compared without adding, so a huge declared length cannot overflow).
                if (length < BASE_LENGTH - LENGTH_FIELD_BYTES || length > MAX_READABLE_BYTES - LENGTH_FIELD_BYTES) {
                    logger.warn("discarding malformed socket data, declared length: " + length);
                    in.skipBytes(in.readableBytes());
                    return;
                }
                int frameLength = length + LENGTH_FIELD_BYTES;
                if (in.readableBytes() < frameLength) {
                    // Frame not complete yet; leave the reader index untouched and wait.
                    return;
                }
                byte[] msgBytes = new byte[frameLength];
                in.readBytes(msgBytes);
                logger.info("full bag body: " + SocketUtils.parse(msgBytes));
                out.add(Unpooled.wrappedBuffer(msgBytes));
            }
        }

        /** Checks the websocket prefix without consuming or copying the buffer. */
        private boolean startsWithWebSocketPrefix(ByteBuf in) {
            int prefixLength = WEBSOCKET_PREFIX.length();
            if (in.readableBytes() < prefixLength) {
                return false;
            }
            int start = in.readerIndex();
            for (int i = 0; i < prefixLength; i++) {
                // The prefix is ASCII, so a char-to-byte comparison is exact.
                if (in.getByte(start + i) != (byte) WEBSOCKET_PREFIX.charAt(i)) {
                    return false;
                }
            }
            return true;
        }

        /** Replaces the socket handling with the HTTP/websocket handler chain. */
        private void switchToWebSocket(ChannelHandlerContext ctx) {
            // websocket runs on top of HTTP, so the HTTP codec is needed first.
            ctx.pipeline().addLast(new HttpServerCodec());
            // Supports writing large data streams in chunks.
            ctx.pipeline().addLast(new ChunkedWriteHandler());
            ctx.pipeline().addLast(new HttpObjectAggregator(8192));
            ctx.pipeline().addLast(new WebSocketHandler());
            ctx.pipeline().addLast(new WebSocketServerProtocolHandler("/ws", null, true, 65536 * 10));
            // Drop the socket handling for this connection.
            ctx.pipeline().remove(NettySocketHandler.class);
            ctx.pipeline().remove(this);
        }
    }
}
Channel连接池实体,存储连接信息:channel不可序列化,因此不能存储到redis缓存中,连接的分布式共享也不能基于简单的流共享方式。当前实现将channel存储在本地缓存中。
import com.alibaba.fastjson.JSON;
import com.galaxyeye.icservice.conf.SpringContextBean;
import com.galaxyeye.icservice.utils.RedisTempleUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.util.Assert;
import java.lang.invoke.MethodHandles;
import java.util.*;
/**
 * @Author: geyingke
 * @Date: 2020/7/20
 * @Class: MyChannelHandlePool
 * @Description: Process-local registry of live Netty channels plus the Redis-backed indexes
 * (channel ids, user/channel relations, online app ids, websocket offline cache, waiting queues).
 * Channel objects are not serializable, so they are only kept in this JVM.
 **/
public class MyChannelHandlePool {
    /**
     * Live channels of this process keyed by {@link ChannelId#asLongText()}.
     * Mutated from several Netty event-loop threads, hence a concurrent map.
     */
    public static Map<String, Channel> CHANNEL_MAP = new java.util.concurrent.ConcurrentHashMap<>();
    /**
     * Channels that passed authentication. Read on every socket message and mutated from
     * several event-loop threads, hence a copy-on-write list.
     */
    public static List<Channel> AUTH_CHANNEL = new java.util.concurrent.CopyOnWriteArrayList<>();

    /**
     * Storage flow of channels and their indexes:
     * 1. When a client channel is established:
     *    websocket: (CHANNEL_GROUP-channel) --> (id@appId-channelId) --> (CHANNEL_KEY-channelId)
     *               --> (USER_KEY:appId-id@appId) --> WS_ONLINE_APP_ID
     *    socket:    (CHANNEL_GROUP-channel) --> (CHANNEL_KEY-channelId)
     * 2. For socket clients the user/channel relation is created when the session is joined:
     *    (id@appId-channelId) --> (USER_KEY:appId-id@appId)
     */
    public MyChannelHandlePool() {
    }

    public static RedisTempleUtil redisTempleUtil = SpringContextBean.getBean(RedisTempleUtil.class);
    private static final Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());

    /** Channel group. */
    public static final String CHANNEL_GROUP = "CHANNEL_GROUP";
    /** Hash of offline websocket users (userInfoId -> timestamp). */
    public static final String OUTLINE_CACHE = "OUTLINE_CACHE";
    public static final String USER_KEY = "USER_KEY";
    public static final String CHANNEL_KEY = "CHANNEL_KEY";
    /** Prefix of the per-app waiting queues. */
    public static final String QUEUE_MARK = "QUEUE_MARK";
    /** Set of appIds that currently have a waiting queue. */
    public static final String QUEUE_APP_ID = "QUEUE_APP_ID";
    /** Channel type: websocket. */
    public static final Integer WS_CHANNEL_TYPE = 1;
    /** Channel type: socket. */
    public static final Integer SOCKET_CHANNEL_TYPE = 2;
    /** Set of appIds with at least one online websocket user. */
    public static final String WS_ONLINE_APP_ID = "WS_ONLINE_APP_ID";
    /** Set of appIds with at least one online socket user. */
    public static final String SOCKET_ONLINE_APP_ID = "SOCKET_ONLINE_APP_ID";
    /** Separator used inside Redis keys. */
    public static final String KEY_SPLIT = ":";
    /** Separator used inside user/queue values ("id@appId"). */
    public static final String DEFAULT_MX = "@";

    /**
     * On server restart, clear all connection-related Redis data. Waiting queues are kept so that
     * queueing continues to work after the restart.
     *
     * @return always 1
     */
    public static Long clearAllChannel() {
        logger.info("============================服务启动初始化,清空所有连接相关数据===========================");
        logger.info("============================(不清除排队队列,保证服务重启后,排队能够正常进行)===========================");
        clearOnlineUserIndexes(WS_ONLINE_APP_ID, WS_CHANNEL_TYPE, "websocket");
        clearOnlineUserIndexes(SOCKET_ONLINE_APP_ID, SOCKET_CHANNEL_TYPE, "socket");
        // Delete every channelId -> user set, then the channel id index itself.
        Set<Object> channelKeys = redisTempleUtil.sGet(CHANNEL_KEY);
        logger.info("待清空的通道id索引" + JSON.toJSONString(channelKeys));
        if (channelKeys != null && !channelKeys.isEmpty()) {
            redisTempleUtil.del(channelKeys.toArray(new String[0]));
        }
        redisTempleUtil.del(CHANNEL_KEY);
        logger.info("====================================连接数据初始化完成===================================");
        return 1L;
    }

    /**
     * Deletes, for every appId in {@code onlineAppIdKey}, the user/channel sets and the
     * USER_KEY index of that app, then the online appId index itself.
     */
    private static void clearOnlineUserIndexes(String onlineAppIdKey, Integer channelType, String label) {
        Set<Object> appIds = redisTempleUtil.sGet(onlineAppIdKey);
        if (appIds == null || appIds.isEmpty()) {
            return;
        }
        logger.info("待清空的" + label + "在线appId索引:" + JSON.toJSONString(appIds));
        for (Object appId : appIds) {
            String userKey = generateUserKey((String) appId, channelType);
            Set<Object> userKeys = redisTempleUtil.sGet(userKey);
            if (userKeys != null && !userKeys.isEmpty()) {
                logger.info("待清空的" + label + "用户索引" + JSON.toJSONString(userKeys));
                redisTempleUtil.del(userKeys.toArray(new String[0]));
            }
            // The USER_KEY set itself would otherwise survive the restart with stale users.
            redisTempleUtil.del(userKey);
        }
        redisTempleUtil.del(onlineAppIdKey);
    }

    /**
     * @param channelId channel id
     * @return whether a channel with this id is registered in this process
     */
    public static boolean hasChannel(ChannelId channelId) {
        return CHANNEL_MAP.containsKey(getStrChannelId(channelId));
    }

    /**
     * Registers a connected channel locally and adds its id to the Redis CHANNEL_KEY index.
     *
     * @param channelId channel id
     * @param channel   the channel
     * @return true if the channel was newly registered, false if the id was already present
     */
    public static boolean addChannel(ChannelId channelId, Channel channel) {
        String strChannelId = getStrChannelId(channelId);
        boolean added = CHANNEL_MAP.putIfAbsent(strChannelId, channel) == null;
        if (added) {
            addToSetIfAbsent(CHANNEL_KEY, strChannelId);
        }
        return added;
    }

    /**
     * @return number of channels registered in this process
     */
    public static Long getChannelGroupSize() {
        return Long.valueOf(CHANNEL_MAP.size());
    }

    /**
     * @param channelId channel id
     * @return the long textual form used as key everywhere in this pool
     */
    public static String getStrChannelId(ChannelId channelId) {
        return channelId.asLongText();
    }

    /**
     * Removes a channel from the local registry.
     *
     * @param channel the channel
     * @return true if this exact channel was registered and has been removed
     */
    public static boolean delChannel(Channel channel) {
        return CHANNEL_MAP.remove(getStrChannelId(channel.id()), channel);
    }

    /**
     * @param channelId channel id in long text form
     * @return the registered channel, or null if unknown (or {@code channelId} is null)
     */
    public static Channel getChannel(String channelId) {
        return channelId == null ? null : CHANNEL_MAP.get(channelId);
    }

    /**
     * Creates the user/channel relation (1 user : n channels) and the related indexes.
     * Also clears the user's offline-cache entry.
     *
     * @param appId       app id
     * @param userAppInfo user key ("id@appId")
     * @param channelId   channel id
     * @param channelType {@link #WS_CHANNEL_TYPE} or {@link #SOCKET_CHANNEL_TYPE}
     * @return true if the relation already existed or was created
     */
    public static boolean createUserChannelRelation(String appId, String userAppInfo, ChannelId channelId, Integer channelType) {
        delOutlineQueue(userAppInfo);
        String strChannelId = getStrChannelId(channelId);
        if (redisTempleUtil.sHasKey(userAppInfo, strChannelId)) {
            return true;
        }
        // user -> channel ids
        long added = redisTempleUtil.sSet(userAppInfo, strChannelId);
        String onlineAppIdKey = onlineAppIdKey(channelType);
        if (onlineAppIdKey != null) {
            addToSetIfAbsent(onlineAppIdKey, appId);
            addToSetIfAbsent(generateUserKey(appId, channelType), userAppInfo);
        }
        // channel id -> users
        addToSetIfAbsent(strChannelId, userAppInfo);
        return added > 0;
    }

    private static String generateUserKey(String appId, Integer channelType) {
        return USER_KEY + KEY_SPLIT + appId + KEY_SPLIT + channelType;
    }

    /** Maps a channel type to its online-appId index key, or null for unknown types. */
    private static String onlineAppIdKey(Integer channelType) {
        if (WS_CHANNEL_TYPE.equals(channelType)) {
            return WS_ONLINE_APP_ID;
        }
        if (SOCKET_CHANNEL_TYPE.equals(channelType)) {
            return SOCKET_ONLINE_APP_ID;
        }
        return null;
    }

    /** Adds {@code value} to the Redis set {@code key} unless it is already a member. */
    private static void addToSetIfAbsent(String key, String value) {
        if (!redisTempleUtil.sHasKey(key, value)) {
            redisTempleUtil.sSet(key, value);
        }
    }

    /**
     * Removes a user (who has no channels left) from the USER_KEY index of its app, and the app
     * from the online-appId index of the channel type once the app has no users left.
     */
    private static void removeUserFromOnlineIndexes(String appId, String userAppInfo, Integer channelType) {
        String onlineKey = onlineAppIdKey(channelType);
        if (onlineKey == null) {
            return;
        }
        String userKey = generateUserKey(appId, channelType);
        redisTempleUtil.setRemove(userKey, userAppInfo);
        if (redisTempleUtil.sGetSetSize(userKey) <= 0) {
            redisTempleUtil.setRemove(onlineKey, appId);
        }
    }

    /**
     * Adds a user to the websocket offline cache.
     *
     * @param userInfoId staffId@appId
     * @param timeStamp  current time in milliseconds
     * @return result of the Redis hash write
     */
    public static boolean addOutlineQueue(String userInfoId, long timeStamp) {
        return redisTempleUtil.hset(OUTLINE_CACHE, userInfoId, timeStamp);
    }

    /**
     * @return all entries of the websocket offline cache
     */
    public static Map getOutlineEntrys() {
        return redisTempleUtil.hEntrys(OUTLINE_CACHE);
    }

    /**
     * Removes a user from the websocket offline cache.
     *
     * @param userInfoId staffId@appId
     * @return number of removed entries (0 if the user was not cached)
     */
    public static Long delOutlineQueue(String userInfoId) {
        if (redisTempleUtil.hHasKey(OUTLINE_CACHE, userInfoId)) {
            return redisTempleUtil.hdel(OUTLINE_CACHE, userInfoId);
        }
        return 0L;
    }

    /**
     * Removes the relation between a user and one channel; when the user has no channels left,
     * the user (and possibly the app) is removed from the online indexes of {@code channelType}.
     *
     * @param appId       app id
     * @param userAppInfo user key ("id@appId")
     * @param channelId   channel id
     * @param channelType {@link #WS_CHANNEL_TYPE} or {@link #SOCKET_CHANNEL_TYPE}
     * @return true if the relation existed and was removed
     */
    public static boolean removeUserChannelRelation(String appId, String userAppInfo, ChannelId channelId, Integer channelType) {
        if (userAppInfo == null || channelId == null) {
            return false;
        }
        long removed = redisTempleUtil.setRemove(userAppInfo, getStrChannelId(channelId));
        if (redisTempleUtil.sGetSetSize(userAppInfo) <= 0) {
            // Socket users are removed from SOCKET_ONLINE_APP_ID (previously WS_ONLINE_APP_ID by mistake).
            removeUserFromOnlineIndexes(appId, userAppInfo, channelType);
        }
        return removed > 0;
    }

    /**
     * @param key   Redis set key
     * @param value member to check
     * @return whether {@code value} is a member of the set {@code key}
     */
    public static boolean hasSetIndex(String key, String value) {
        return redisTempleUtil.sHasKey(key, value);
    }

    /**
     * Removes {@code values} from the set {@code key}. For {@link #CHANNEL_KEY} the relations of
     * the given channel ids are cleaned up as well.
     *
     * @param key        Redis set key
     * @param channelKey channel type ({@link #WS_CHANNEL_TYPE} / {@link #SOCKET_CHANNEL_TYPE})
     * @param values     members to remove (channel ids when {@code key} is CHANNEL_KEY)
     * @return user keys that were linked to the removed channels, or null for other keys
     */
    public static Set<Object> removeChannelRelation(String key, Integer channelKey, String... values) {
        redisTempleUtil.setRemove(key, values);
        if (CHANNEL_KEY.equals(key)) {
            return removeChannelKeyAssociation(values, channelKey);
        }
        return null;
    }

    /**
     * @param key Redis set key
     * @return all members of the set
     */
    public static Set<Object> getSetValue(String key) {
        return redisTempleUtil.sGet(key);
    }

    public static Set<Object> getUserKeySet(String appId, Integer channelType) {
        return redisTempleUtil.sGet(generateUserKey(appId, channelType));
    }

    public static Set<Object> getUserKeyChannelIdSet(String appId, String userId) {
        return redisTempleUtil.sGet(generateSetKey(userId, appId));
    }

    /**
     * For each channel id: unlinks it from every user, updates the online indexes of users left
     * without channels, deletes the channel -> users set and drops the local channel entry.
     */
    private static Set<Object> removeChannelKeyAssociation(String[] channelKeys, Integer channelType) {
        Assert.notNull(channelKeys, "channelKeys must not be null");
        Set<Object> result = new HashSet<>();
        for (String channelKey : channelKeys) {
            if (channelKey == null) {
                continue;
            }
            Set<Object> userAppKeySet = redisTempleUtil.sGet(channelKey);
            if (userAppKeySet != null && !userAppKeySet.isEmpty()) {
                for (Object userAppKey : userAppKeySet) {
                    String userAppInfo = (String) userAppKey;
                    redisTempleUtil.setRemove(userAppInfo, channelKey);
                    String[] parts = decodeUserAppSetKey(userAppInfo);
                    // Only mark the user offline once no channel of it remains (same rule as removeUserChannelRelation).
                    if (parts != null && parts.length > 1 && redisTempleUtil.sGetSetSize(userAppInfo) <= 0) {
                        removeUserFromOnlineIndexes(parts[1], userAppInfo, channelType);
                    }
                }
                result.addAll(userAppKeySet);
            }
            // The channel id was already removed from CHANNEL_KEY, so nothing else would clean this set up.
            redisTempleUtil.del(channelKey);
            CHANNEL_MAP.remove(channelKey);
        }
        return result;
    }

    /**
     * Builds the user -> channels key.
     *
     * @param userInfoId user id
     * @param appId      app id
     * @return "userInfoId@appId"
     */
    public static String generateSetKey(String userInfoId, String appId) {
        return userInfoId + DEFAULT_MX + appId;
    }

    /**
     * Splits a user -> channels key.
     *
     * @param userAppKey "userInfoId@appId"
     * @return the parts, or null if the key contains no separator
     */
    public static String[] decodeUserAppSetKey(String userAppKey) {
        return userAppKey.contains(DEFAULT_MX) ? userAppKey.split(DEFAULT_MX) : null;
    }

    /**
     * @return set of appIds that have a waiting queue
     */
    public static Set<Object> getQueueIndexSet() {
        return redisTempleUtil.sGet(QUEUE_APP_ID);
    }

    private static String queueKey(String appId) {
        return QUEUE_MARK + KEY_SPLIT + appId;
    }

    /**
     * Appends an entry to the right of the app's waiting queue, unless it is already queued.
     *
     * @param appId       app id
     * @param uId         user id
     * @param chatPackSeq chat package sequence
     * @return the value's index as reported by the Redis helper
     */
    public static long rightPushQueue(String appId, String uId, String chatPackSeq) {
        String key = queueKey(appId);
        String value = encodeQueueValue(chatPackSeq, appId, uId);
        long valueIndex = redisTempleUtil.lhasKeyAndValue(key, value);
        if (valueIndex > -1) {
            return valueIndex;
        }
        // Register the appId in the queue index.
        addToSetIfAbsent(QUEUE_APP_ID, appId);
        return redisTempleUtil.lrightSet(key, value);
    }

    /**
     * @return the index of the entry in the app's waiting queue as reported by the Redis helper
     */
    public static long checkAndReturnIndex(String appId, String uId, String chatPackSeq) {
        return redisTempleUtil.lhasKeyAndValue(queueKey(appId), encodeQueueValue(chatPackSeq, appId, uId));
    }

    /**
     * @param appId app id
     * @return length of the app's waiting queue
     */
    public static long getQueueSizeByAppId(String appId) {
        return redisTempleUtil.lGetListSize(queueKey(appId));
    }

    /**
     * Pops the leftmost entry of the app's waiting queue.
     *
     * @param appId app id
     * @return the popped value
     */
    public static Object leftPopQueue(String appId) {
        return redisTempleUtil.leftPopListValue(queueKey(appId));
    }

    /**
     * Removes one matching entry from the app's waiting queue and drops the appId from the
     * queue index when the queue becomes empty.
     *
     * @return result of the Redis list removal
     */
    public static long removeOneInQueue(String appId, String uId, String chatPackSeq) {
        String key = queueKey(appId);
        long removed = redisTempleUtil.lRemove(key, 1, encodeQueueValue(chatPackSeq, appId, uId));
        if (redisTempleUtil.lGetListSize(key) == 0) {
            redisTempleUtil.setRemove(QUEUE_APP_ID, appId);
        }
        return removed;
    }

    /**
     * @return "chatPackSeq@appId@uId"
     */
    public static String encodeQueueValue(String chatPackSeq, String appId, String uId) {
        return new StringBuilder(chatPackSeq).append(DEFAULT_MX).append(appId).append(DEFAULT_MX).append(uId).toString();
    }

    /**
     * @param value "chatPackSeq@appId@uId"
     * @return the parts, or null if the value contains no separator
     */
    public static String[] decodeQueueValue(String value) {
        return value.contains(DEFAULT_MX) ? value.split(DEFAULT_MX) : null;
    }
}
websocket处理类(Handler)
import com.alibaba.fastjson.JSON;
import com.galaxyeye.icservice.im.MyChannelHandlePool;
import io.netty.channel.*;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.stereotype.Component;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
/**
 * @Author: geyingke
 * @Date: 2020/7/20
 * @Class: MyWebSocketHandler
 * @Description: websocket request handler. Rewrites the handshake URI (stripping query
 * parameters), echoes text frames back to the sender and keeps {@link MyChannelHandlePool}
 * in sync with the channel lifecycle.
 **/
@Component
@ChannelHandler.Sharable
public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
    private static final Logger logger = LogManager.getLogger(WebSocketHandler.class);

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame msg) throws Exception {
        // Frames are handled in channelRead; this is only reached via super.channelRead, which then releases msg.
        logger.info("come here~!");
    }

    /**
     * Registers the channel in the pool.
     * NOTE(review): this handler is added to an already-active pipeline by SocketParser, in which
     * case channelActive is not fired for it — confirm registration also happens elsewhere.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();
        InetSocketAddress socketAddress = (InetSocketAddress) channel.remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        ChannelId channelId = channel.id();
        if (MyChannelHandlePool.hasChannel(channelId)) {
            logger.info(String.format("websocket客户端【%s】是连接状态,连接通道数量:%d", channelId, MyChannelHandlePool.getChannelGroupSize()));
        } else {
            MyChannelHandlePool.addChannel(channelId, channel);
            logger.info(String.format("websocket客户端【%s】连接netty服务器[IP:%s--->PORT:%d]", channelId, clientIp, clientPort));
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            // The first request is the HTTP handshake.
            FullHttpRequest fullHttpRequest = (FullHttpRequest) msg;
            String uri = fullHttpRequest.uri();
            Map<String, String> paramMap = getUrlParams(uri);
            logger.info("received msg ==>" + JSON.toJSONString(paramMap));
            // Strip query parameters so the protocol handler's "/ws" path check matches.
            int queryStart = uri.indexOf('?');
            fullHttpRequest.setUri(queryStart >= 0 ? uri.substring(0, queryStart) : "/ws");
        } else if (msg instanceof TextWebSocketFrame) {
            TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) msg;
            logger.info(String.format("服务端接收到的消息:%s", textWebSocketFrame.text()));
            // TODO: real message handling
            sendMessage(ctx.channel(), textWebSocketFrame.text());
        }
        // Text frames go to channelRead0 (and are released there); everything else is forwarded.
        super.channelRead(ctx, msg);
    }

    private void sendMessage(Channel channel, String text) {
        sendAllMessage(channel, text);
    }

    /** Writes {@code message} back to the given channel only (no broadcast). */
    private void sendAllMessage(Channel channel, String message) {
        channel.writeAndFlush(new TextWebSocketFrame(message));
    }

    /**
     * Unregisters the channel locally and removes its Redis channel index and user relations,
     * mirroring the socket handler.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();
        InetSocketAddress socketAddress = (InetSocketAddress) channel.remoteAddress();
        String clientIp = socketAddress.getAddress().getHostAddress();
        int clientPort = socketAddress.getPort();
        ChannelId channelId = channel.id();
        if (!MyChannelHandlePool.hasChannel(channelId)) {
            return;
        }
        boolean removed = MyChannelHandlePool.delChannel(channel);
        String strChannelId = MyChannelHandlePool.getStrChannelId(channelId);
        if (MyChannelHandlePool.hasSetIndex(MyChannelHandlePool.CHANNEL_KEY, strChannelId)) {
            MyChannelHandlePool.removeChannelRelation(MyChannelHandlePool.CHANNEL_KEY, MyChannelHandlePool.WS_CHANNEL_TYPE, strChannelId);
        }
        if (removed) {
            logger.info(String.format("websocket客户端【%s】成功下线![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
            logger.info(String.format("websocket连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
        } else {
            logger.error(String.format("websocket客服端【%s】下线失败![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
        }
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        logger.info("channelReadComplete");
        ctx.flush();
    }

    /** Disconnects the client on any idle event from the IdleStateHandler; other events are propagated. */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof IdleStateEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        String socketString = String.valueOf(ctx.channel().remoteAddress());
        IdleState state = ((IdleStateEvent) evt).state();
        if (state == IdleState.READER_IDLE) {
            logger.info(String.format("Client: %s READER_IDLE 读超时", socketString));
        } else if (state == IdleState.WRITER_IDLE) {
            logger.info(String.format("Client: %s WRITER_IDLE 写超时", socketString));
        } else if (state == IdleState.ALL_IDLE) {
            logger.info(String.format("Client: %s ALL_IDLE 总超时", socketString));
        }
        ctx.disconnect();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("websocket消息处理异常!", cause);
        ctx.close();
    }

    /**
     * Parses the query string of {@code url} into a map. Pairs without '=' map to "" and
     * empty pairs are skipped; a URL without (or with an empty) query yields an empty map.
     */
    private static Map<String, String> getUrlParams(String url) {
        Map<String, String> map = new HashMap<>();
        int queryStart = url.indexOf('?');
        if (queryStart < 0 || queryStart == url.length() - 1) {
            return map;
        }
        for (String pair : url.substring(queryStart + 1).split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int eq = pair.indexOf('=');
            if (eq < 0) {
                map.put(pair, "");
            } else {
                map.put(pair.substring(0, eq), pair.substring(eq + 1));
            }
        }
        return map;
    }
}
socket处理类(Handler)
继承SimpleChannelInboundHandler<ByteBuf>。其泛型不能与websocket handler的泛型相同,否则无法区分channel类型,需要把上线和离线的处理单独提取出来。
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.galaxyeye.icservice.conf.RedisOperator;
import com.galaxyeye.icservice.conf.SpringContextBean;
import com.galaxyeye.icservice.conf.myException.CatchedReturnException;
import com.galaxyeye.icservice.conf.myException.DataBaseException;
import com.galaxyeye.icservice.constant.ReturnEnum;
import com.galaxyeye.icservice.entity.ValidatorVo;
import com.galaxyeye.icservice.im.MyChannelHandlePool;
import com.galaxyeye.icservice.im.parser.SocketUtils;
import com.galaxyeye.icservice.im.protocol.SocketProtocol;
import com.galaxyeye.icservice.im.protocol.WebSocketProtocol;
import com.galaxyeye.icservice.service.im.socket.SocketHandlerService;
import com.galaxyeye.icservice.utils.RandomUtils;
import com.galaxyeye.icservice.utils.WSMsgUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.lang.invoke.MethodHandles;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/**
* @Author: geyingke
* @Date: 2020/7/21
* @Class: NettyWebSocketHandler
* @Discription: TODO
**/
public class NettySocketHandler extends SimpleChannelInboundHandler<ByteBuf> {
private Logger logger = LogManager.getLogger(MethodHandles.lookup().lookupClass());
private RedisOperator redisOperator = SpringContextBean.getBean(RedisOperator.class);
private SocketHandlerService socketHandlerService = SpringContextBean.getBean(SocketHandlerService.class);
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
String clientIp = socketAddress.getAddress().getHostAddress();
int clientPort = socketAddress.getPort();
ChannelId channelId = ctx.channel().id();
if (MyChannelHandlePool.hasChannel(channelId)) {
logger.info(String.format("socket客户端【%s】是连接状态,连接通道数量:%d", channelId, MyChannelHandlePool.getChannelGroupSize()));
} else {
//将channel添加到组
MyChannelHandlePool.addChannel(channelId, ctx.channel());
logger.info(String.format("socket客户端【%s】连接netty服务器[IP:%s--->PORT:%d]", channelId, clientIp, clientPort));
logger.info(String.format("客户端连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
logger.info("---------------------socket断线检测-------------------------");
InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
String clientIp = socketAddress.getAddress().getHostAddress();
int clientPort = socketAddress.getPort();
ChannelId channelId = ctx.channel().id();
if (MyChannelHandlePool.hasChannel(channelId)) {
boolean aLong = MyChannelHandlePool.delChannel(ctx.channel());
//校验通道id,并下线
if (MyChannelHandlePool.hasSetIndex(MyChannelHandlePool.CHANNEL_KEY, MyChannelHandlePool.getStrChannelId(channelId))) {
MyChannelHandlePool.removeChannelRelation(MyChannelHandlePool.CHANNEL_KEY, MyChannelHandlePool.SOCKET_CHANNEL_TYPE, MyChannelHandlePool.getStrChannelId(channelId));
}
if (aLong) {
logger.info(String.format("socket客户端【%s】成功下线![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
logger.info(String.format("连接通道数量:%d", MyChannelHandlePool.getChannelGroupSize()));
} else {
logger.error(String.format("socket客户端【%s】下线失败![IP:%s--> PORT:%d]", channelId, clientIp, clientPort));
}
}
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
logger.info("come here");
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
try {
ByteBuf byteBuf = (ByteBuf) msg;
Map<String, Object> reqMap = SocketUtils.parseByteBuffMap(byteBuf);
String msgBody = (String) reqMap.get(SocketUtils.BODY);
int cmd = (int) reqMap.get(SocketUtils.CMD);
logger.info("msg coverted : " + msgBody);
JSONObject receiveMsg = JSON.parseObject(msgBody);
if (receiveMsg.containsKey(SocketProtocol.TYPE)) {
String msgType = receiveMsg.getString(SocketProtocol.TYPE);
if (MyChannelHandlePool.AUTH_CHANNEL.contains(ctx.channel())) {
switch (msgType) {
case SocketProtocol.TRANS_IC:
tranIc(ctx, receiveMsg);
break;
case SocketProtocol.GET_QUEUE_INFO:
getQueueInfo(ctx, receiveMsg);
break;
case SocketProtocol.EXIT_QUEUE:
exitQueue(ctx, receiveMsg);
break;
case SocketProtocol.CHAT_SEND:
chatSend(ctx, receiveMsg);
break;
case SocketProtocol.FEEDBACK:
feedBack(ctx, receiveMsg);
break;
case SocketProtocol.CHAT_OFFLINE:
chatOffline(ctx, receiveMsg);
break;
case SocketProtocol.CHAT_EXIT:
chatExit(ctx, receiveMsg);
break;
case SocketProtocol.CHAT_RECONNECT:
reconnect(ctx, receiveMsg);
break;
default:
break;
}
} else {
JSONObject res = new JSONObject() {{
put("type", "error");
put("retCode", ReturnEnum.UNAUTHED_CHANNEL.getRet_msg());
put("retMsg", ReturnEnum.UNAUTHED_CHANNEL.getRet_msg());
}};
writeBack(ctx, res.toJSONString());
}
} else {
if (cmd == 201) {
Integer servType = receiveMsg.getInteger("servType");
String appid = receiveMsg.getString("appid");
handle201msg(ctx, appid, servType);
} else if (cmd == 202) {
String appid = receiveMsg.getString("appid");
String sign = receiveMsg.getString("sign");
handle202msg(ctx, appid, sign);
} else {
//默认的心跳包处理,直接返回
logger.info("return msg: " + msgBody);
writeHeartBeatBack(ctx, msgBody);
}
}
} catch (CatchedReturnException ce) {
logger.error("消息内容不合法!", ce);
String dispose = JSON.toJSONString(new HashMap<String, Object>() {{
put(SocketProtocol.TYPE, SocketProtocol.ERROR);
put(SocketProtocol.DATA, new StringBuffer("消息内容不合法!:").append(ce.getMessage()));
}});
writeBack(ctx, dispose);
} catch (Exception e) {
logger.error("消息处理异常:", e);
String dispose = JSON.toJSONString(new HashMap<String, Object>() {{
put(SocketProtocol.TYPE, SocketProtocol.ERROR);
put(SocketProtocol.DATA, new StringBuffer("消息内容不合法!:").append(e.getMessage()));
}});
writeBack(ctx, dispose);
} finally {
//释放内存
ReferenceCountUtil.release(msg);
}
}
/**
* 处理返回信息
*
* @param ctx
* @param dispose
*/
private void writeBack(ChannelHandlerContext ctx, String dispose) {
ByteBuf resp = Unpooled.copiedBuffer(SocketUtils.pack(dispose.getBytes(), 0));
ctx.writeAndFlush(resp);
}
/**
* 处理返回信息
*
* @param ctx
* @param dispose
*/
private void writeHeartBeatBack(ChannelHandlerContext ctx, String dispose) {
ByteBuf resp = Unpooled.copiedBuffer(SocketUtils.pack(dispose.getBytes(), 9999));
ctx.writeAndFlush(resp);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
String socketString = ctx.channel().remoteAddress().toString();
if (evt instanceof IdleStateEvent) {
IdleStateEvent event = (IdleStateEvent) evt;
if (event.state() == IdleState.READER_IDLE) {
logger.info(String.format("Client: %s READER_IDLE 读超时", socketString));
ctx.disconnect();
} else if (event.state() == IdleState.WRITER_IDLE) {
logger.info(String.format("Client: %s WRITER_IDLE 写超时", socketString));
ctx.disconnect();
} else if (event.state() == IdleState.ALL_IDLE) {
logger.info(String.format("Client: %s ALL_IDLE 总超时", socketString));
ctx.disconnect();
}
} else {
super.userEventTriggered(ctx, evt);
}
}
}
socket消息封包、解包处理类:
import io.netty.buffer.ByteBuf;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * @Author: geyingke
 * @Date: 2020/8/4
 * @Class: SocketUtils
 * @Description: Packs and unpacks socket frames. Layout (all little-endian):
 * pkgLen(4, = frame size - 4) | checkSum(4) | cmd(2) | target(2) | retCode(2) | body.
 **/
public class SocketUtils {
    /** Header bytes counted by pkgLen: checkSum + cmd + target + retCode. */
    static final int HEAD_SIZE = 10;
    /** Full header size including the 4-byte pkgLen field. */
    static final int TOTAL_SIZE = 14;
    /** Offset of the 2-byte cmd field. */
    private static final int CMD_OFFSET = 8;
    /**
     * cmd of the last frame decoded by {@link #parse(byte[])}.
     * NOTE(review): shared static state; racy if parse runs on several threads — prefer parseMap.
     */
    static int cmd;
    private static final Logger logger = LogManager.getLogger(SocketUtils.class);

    public static final String CMD = "CMD";
    public static final String BODY = "BODY";

    static void write_short_le(byte[] buf, int offset, short value) {
        buf[offset + 1] = (byte) ((value >> 8) & 0xff);
        buf[offset] = (byte) (value & 0xff);
    }

    static void write_int_le(byte[] buf, int offset, int value) {
        buf[offset + 3] = (byte) ((value >> 24) & 0xff);
        buf[offset + 2] = (byte) ((value >> 16) & 0xff);
        buf[offset + 1] = (byte) ((value >> 8) & 0xff);
        buf[offset] = (byte) (value & 0xff);
    }

    /** Copies {@code src[src_offset..]} into {@code dst} starting at {@code dst_offset}; no-op if nothing to copy. */
    static void write_bytes(byte[] src, int src_offset, byte[] dst, int dst_offset) {
        int count = src.length - src_offset;
        if (count > 0) {
            System.arraycopy(src, src_offset, dst, dst_offset, count);
        }
    }

    /** Reads a little-endian 16-bit value (both bytes; the high byte used to be masked away). */
    static short read_short_le(byte[] data, int offset) {
        return (short) ((data[offset] & 0xFF) | ((data[offset + 1] & 0xFF) << 8));
    }

    /** Reads a little-endian 32-bit value. Each byte is masked before shifting ({@code <<} binds tighter than {@code &}). */
    public static int read_int_le(byte[] data, int offset) {
        return (data[offset] & 0xFF)
                | ((data[offset + 1] & 0xFF) << 8)
                | ((data[offset + 2] & 0xFF) << 16)
                | ((data[offset + 3] & 0xFF) << 24);
    }

    /** Reads all readable bytes of {@code msg}, advancing its reader index. */
    private static byte[] readAll(ByteBuf msg) {
        byte[] bytes = new byte[msg.readableBytes()];
        msg.readBytes(bytes);
        logger.debug("msg before covert: " + new String(bytes));
        return bytes;
    }

    /**
     * Reads all readable bytes of {@code msg} and returns the frame body; also updates {@link #cmd}.
     */
    public static String parseByteBuff(ByteBuf msg) {
        return parse(readAll(msg));
    }

    /**
     * Returns the body of a frame and stores its cmd in {@link #cmd}.
     *
     * @throws IllegalArgumentException if the frame is too short or declares a negative body size
     */
    public static String parse(byte[] bytes) {
        Map<String, Object> frame = parseMap(bytes);
        cmd = (Integer) frame.get(CMD);
        return (String) frame.get(BODY);
    }

    /**
     * Reads all readable bytes of {@code msg} and decodes them with {@link #parseMap(byte[])}.
     */
    public static Map<String, Object> parseByteBuffMap(ByteBuf msg) {
        return parseMap(readAll(msg));
    }

    /**
     * Decodes a frame into {@link #CMD} (Integer) and {@link #BODY} (String). The body is the
     * declared {@code pkgLen - HEAD_SIZE} bytes after the header, truncated to what is present.
     *
     * @throws IllegalArgumentException if the frame is too short or declares a negative body size
     */
    public static Map<String, Object> parseMap(byte[] bytes) {
        if (bytes.length < CMD_OFFSET + 2) {
            throw new IllegalArgumentException("socket frame too short: " + bytes.length + " bytes");
        }
        int plen = read_int_le(bytes, 0);
        int contentSize = plen - HEAD_SIZE;
        if (contentSize < 0) {
            throw new IllegalArgumentException("invalid socket frame pkgLen: " + plen);
        }
        int cmdx = read_short_le(bytes, CMD_OFFSET);
        int bodyLength = Math.max(0, Math.min(contentSize, bytes.length - TOTAL_SIZE));
        String body = bodyLength == 0 ? "" : new String(bytes, TOTAL_SIZE, bodyLength);
        Map<String, Object> result = new HashMap<>();
        result.put(CMD, cmdx);
        result.put(BODY, body);
        return result;
    }

    /**
     * Packs {@code content} into a frame with the given cmd; checkSum, target and retCode are 0.
     *
     * @param content body bytes
     * @param cmd     command (stored as 16 bits)
     * @return the complete frame
     */
    public static byte[] pack(byte[] content, int cmd) {
        int total_size = content.length + TOTAL_SIZE;
        byte[] msg = new byte[total_size];
        write_int_le(msg, 0, total_size - 4);   // pkgLen excludes itself
        write_int_le(msg, 4, 0);                // checkSum
        write_short_le(msg, CMD_OFFSET, (short) cmd);
        write_short_le(msg, 10, (short) 0);     // target
        write_short_le(msg, 12, (short) 0);     // retCode
        write_bytes(content, 0, msg, TOTAL_SIZE);
        return msg;
    }
}