dwsocket.h
#pragma once
#include <pthread.h>
#include <stdlib.h>
#include <sys/epoll.h>

#include <cerrno>
#include <condition_variable>
#include <cstdint>
#include <iosfwd>
#include <mutex>
#include <queue>
#include <string>
namespace dw
{
// NOTE(review): a using-directive in a header exposes all of std inside
// namespace dw for every includer. It is kept because the .cpp files rely on
// unqualified std names (string, ostream, cout) inside namespace dw.
using namespace std;

/**
 * Endpoint description: an IP string plus a port number.
 */
class Addr
{
public:
    uint16_t port = 0;  // zero-initialized so a default-constructed Addr never holds garbage
    string ip;
public:
    Addr(string ip, uint16_t port);
    Addr();
    ~Addr();
    /// Writes the address as "ip:port".
    friend ostream& operator<<(ostream &out, const Addr &addr);
private:
};

/**
 * Client connection handed to a ServerSocketCallBack.
 * Holds a descriptor number assigned by the server.
 */
class Socket
{
public:
    int socketFd = -1;  // -1 until the server assigns a client descriptor
public:
    Socket();
    ~Socket();
    /// Reads from the connection into `data` (see Socket.cpp for current behavior).
    int read(void *data);
private:
};

/// Callback invoked on a worker thread for each client descriptor with pending input.
using ServerSocketCallBack = void (*)(Socket*);

/*
 * Server socket: a non-blocking listening TCP socket driven by an epoll loop
 * on the thread that calls start(); client descriptors with pending input are
 * handed to a pool of worker threads that invoke `callBack`.
 */
class ServerSocket
{
public:
    Addr *addr = NULL;        // owned; deleted in the destructor
    int maxListenNum = 128;   // listen() backlog (libevent also uses 128)
    int epollSize = 65535;    // size hint passed to epoll_create (must be > 0)
    int maxEvents = 1024;     // capacity of the epoll_wait event buffer
    int maxThread = 1;        // worker count; recomputed from the CPU count by the constructor
    ServerSocketCallBack callBack = NULL;  // must be set before start()
public:
    explicit ServerSocket(uint16_t port);
    ~ServerSocket();
    // Owns raw resources (addr, events, descriptors): copying would double-release them.
    ServerSocket(const ServerSocket&) = delete;
    ServerSocket& operator=(const ServerSocket&) = delete;
    int start();
private:
    int socketfd = -1;           // listening socket
    int epollfd = -1;            // epoll instance
    epoll_event* events = NULL;  // epoll_wait buffer of maxEvents entries (new[])
    // NOTE(review): declared int* although pthread_t is not guaranteed to fit in
    // an int; should become pthread_t* if thread handles are stored here.
    int *threadArry = NULL;
    queue<int> taskQueue;        // task queue: client descriptors waiting for a worker
    // Guards taskQueue. NOTE(review): it does not stop two workers from
    // receiving the same descriptor if it is queued twice.
    mutex taskQueueLock;
private:
    void initAddr(uint16_t port);
    int initSocket();
    int initEpoll();
    void setNonBlock(int fd);   // put a descriptor into non-blocking mode
    void initMaxThread();
    void loopWait();
    void addSocketFd(int fd);
    int initThreadPool();       // start the worker thread pool
    static void * threadMain(void* arg);  // worker thread entry point
};
}
Addr.cpp
#include "dwsocket.h"
#include<iostream>
namespace dw
{
// Default-constructed address: empty IP and port 0. The port is initialized
// explicitly because operator<< reads it.
Addr::Addr()
    : port(0)
{
}
// Builds an address from an IP string and a port number.
Addr::Addr(string ip, uint16_t port)
    : port(port), ip(ip)
{
}
Addr::~Addr()
{
}
// Writes the address as "ip:port" and returns the stream for chaining.
ostream& operator<<(ostream &out, const Addr &addr)
{
    out << addr.ip << ":" << addr.port;
    return out;
}
}
ServerSocket.cpp
#include "dwsocket.h"
#include<iostream>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/sysinfo.h>
#include <unistd.h>
namespace dw
{
using namespace std;
/**
 * Prepares a server for `port`; nothing is opened until start().
 * `events` is explicitly nulled so the destructor can always free it.
 */
ServerSocket::ServerSocket(uint16_t port)
    : events(NULL)
{
    initAddr(port);
    initMaxThread();
}

/**
 * Releases the address, the epoll event buffer and both descriptors.
 *
 * Worker threads are deliberately not touched: pthread_exit() terminates the
 * *calling* thread (not the workers), and threadArry does not reliably point
 * at live thread handles.
 * NOTE(review): workers loop forever and hold `this`; destroying a server
 * whose workers are running is still unsafe until a stop flag is added.
 */
ServerSocket::~ServerSocket()
{
    delete addr;
    delete[] events;
    if (epollfd >= 0)
    {
        close(epollfd);
    }
    if (socketfd >= 0)
    {
        shutdown(socketfd, SHUT_RDWR);
        close(socketfd);  // shutdown() alone does not release the descriptor
    }
}
/// Sizes the worker pool at two threads per processor reported by get_nprocs().
void ServerSocket::initMaxThread()
{
    maxThread = 2 * get_nprocs();
}
void ServerSocket::initAddr(uint16_t port)
{
addr = new Addr("127.0.0.1", port);
}
/**
 * Creates the listening TCP socket, makes it non-blocking, enables
 * SO_REUSEADDR, binds it to addr->port on all interfaces (INADDR_ANY) and
 * starts listening with a backlog of maxListenNum.
 *
 * @return 0 on success, -1 on failure. If bind/listen fails, the descriptor
 *         is closed and socketfd is reset to -1 so it does not leak.
 *
 * NOTE(review): addr->ip is not used for bind; the server accepts
 * connections on every interface.
 */
int ServerSocket::initSocket()
{
    if (maxListenNum <= 0)
    {
        cout << "Error: maxListenNum 必须大于0 " << endl;
        return -1;
    }
    socketfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (socketfd < 0)
    {
        cout << "Error: socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)" << endl;
        return -1;
    }
    setNonBlock(socketfd);
    int opt = 1;
    if (setsockopt(socketfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0)
    {
        // Not fatal: without SO_REUSEADDR a quick restart may fail to bind.
        cout << "Warning: setsockopt(socketfd, SOL_SOCKET, SO_REUSEADDR)" << endl;
    }
    struct sockaddr_in server_addr = {};  // value-init clears sin_zero and padding
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(addr->port);
    if (bind(socketfd, reinterpret_cast<struct sockaddr*>(&server_addr), sizeof(server_addr)) < 0)
    {
        cout << "Error: bind(socketfd, (struct sockaddr*) &server_addr, sizeof(server_addr))" << endl;
        close(socketfd);
        socketfd = -1;
        return -1;
    }
    if (listen(socketfd, maxListenNum) < 0)
    {
        cout << "Error: listen(socketfd, maxListenNum)" << endl;
        close(socketfd);
        socketfd = -1;
        return -1;
    }
    return 0;
}
void ServerSocket::setNonBlock(int fd)
{
int fl = fcntl(socketfd, F_GETFL);
fcntl(socketfd, F_SETFL, fl | O_NONBLOCK);
}
int ServerSocket::initEpoll()
{
epollfd = epoll_create(epollSize);
if (epollfd < 0)
{
cout << "Error: epoll_create(maxEvents)" << endl;
return -1;
}
events = new epoll_event[maxEvents];
addSocketFd(socketfd);
return 0;
}
void ServerSocket::loopWait()
{
struct sockaddr_in clientaddr;
socklen_t clilen = sizeof(clientaddr);
while (1)
{
int num = epoll_wait(epollfd, events, maxEvents, -1); //返回活跃用户个数
for (int i = 0; i < num; i++)
{
int connfd = -1;
if (events[i].data.fd == socketfd)
{
cout << "有新用户连接" << endl;
connfd = accept(socketfd, (struct sockaddr *) &clientaddr, &clilen);
if (connfd < 0)
{
cout << "连接失败" << endl;
continue;
}
setNonBlock(connfd);
}
else if (events[i].events&EPOLLIN)
{
cout << "有用户发送数据" << endl;
if ((connfd = events[i].data.fd) < 0)
{
cout << "连接失败" << endl;
continue;
}
taskQueueLock.lock();
taskQueue.push(connfd);
taskQueueLock.unlock();
}
else if (events[i].events&EPOLLOUT)
{
//以前的连接,有数据写出
cout << "以前的连接 写出" << endl;
}
else
{
cout << "其他" << endl;
continue;
}
addSocketFd(connfd);
}
}
}
void ServerSocket::addSocketFd(int fd)
{
struct epoll_event event;
event.data.fd = fd;
event.events = EPOLLIN | EPOLLET;
epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event);
}
int ServerSocket::initThreadPool()
{
pthread_t threads[maxThread];
threadArry = (int*)threads;
for (int i = 0; i < maxThread; i++)
{
int ret = pthread_create(&threads[i], NULL, threadMain, this);
if (ret < 0)
{
cout << "Error: initThreadPool()" << endl;
return -1;
}
}
return 0;
}
void * ServerSocket::threadMain(void* arg)
{
ServerSocket *socket = (ServerSocket*)arg;
while (true)
{
int socketFd = 0;
socket->taskQueueLock.lock();
if (!socket->taskQueue.empty()) {
socketFd = socket->taskQueue.front();
socket->taskQueue.pop();
}
socket->taskQueueLock.unlock();
if (socketFd <= 0)
{
usleep(100000);
}
else
{
Socket client;
client.socketFd = socketFd;
socket->callBack(&client);
}
}
return NULL;
}
/**
 * Validates the configuration, creates the listening socket and epoll
 * instance, starts the worker pool and then runs the event loop on the
 * calling thread.
 *
 * Workers are started last so that a socket/epoll failure does not leave
 * threads running against an object that is about to be destroyed.
 *
 * @return -1 if setup fails; 0 if the event loop ever returns.
 */
int ServerSocket::start()
{
    if (callBack == NULL) {
        cout << "Error: callBack is NULL" << endl;
        return -1;
    }
    if (initSocket() < 0)
    {
        cout << "Error: initSocket()" << endl;
        return -1;
    }
    if (initEpoll() < 0)
    {
        cout << "Error: initEpoll()" << endl;
        return -1;
    }
    if (initThreadPool() < 0)
    {
        cout << "Error: initThreadPool()" << endl;
        return -1;
    }
    loopWait();
    return 0;
}
}
Socket.cpp
#include "dwsocket.h"
namespace dw {
// A Socket only carries a descriptor number assigned by the server;
// constructing or destroying one acquires or releases nothing.
Socket::Socket() = default;

Socket::~Socket() = default;

// Placeholder: reading is not implemented yet. Always reports failure (-1)
// and leaves `data` untouched.
int Socket::read(void *data) {
    static_cast<void>(data);
    return -1;
}
}
main.cpp
#include "dwsocket.h"
#include <iostream>
using namespace dw;
using namespace std;
// Demo connection callback: logs the event, then issues a read with no buffer.
void test(Socket *socket) {
    std::cout << "test 有用户连接进来" << std::endl;
    socket->read(nullptr);
}
int main() {
ServerSocket server(10000);
static ServerSocketCallBack callback = test;
server.callBack = callback;
server.start();
return 0;
}
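编译时需要添加 -pthread（或链接 -lpthread），并启用 C++11 或更高标准，例如：g++ -std=c++11 -pthread Addr.cpp ServerSocket.cpp Socket.cpp main.cpp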