在做web开发,公司提出一个内网穿透的需求,考虑使用花生壳之类的工具,原来对这部分内容一直比较感兴趣,顺手研究了一下;
代码在最后,先整理一下思路:
1.
外网用户访问一个网页,首先需要一个固定的地址/域名,这样你必须有一台拥有公网IP的服务器,来实现
用户 ==> 中介服务器
这个请求。
2.
当外网用户访问中介服务器之后,怎么从内网服务器获取数据呢?
首先外网服务器不能主动访问内网内容,但是可以被动应答;
所以我们可以在中介服务器开启一个tcp服务监听,让内网主机主动连接中介服务器,当需要获取数据的时候就可以通过这种方式获取数据;这样
中介服务器 <== 内网主机的链接就实现了
3.
然后链接1,2步骤中的内容就可以实现内网穿透了,当然这个要求内网主机是可以访问外网资源的,
用户 <==> 中介服务器 <==> 内网主机
如果内网主机不能访问外网资源,那就需要一台能否同时访问内网和外网的主机来实现
用户 <==> 中介服务器 <==> DMZ主机 <==> 内网主机
大概的思路就是这样,实现中还需要一个在内网里运行的程序中做一个控制指令tcp长链接,用于当用户请求网址时中介主机向内网发送通知说“用户请求内容了,赶紧跟我的隧道监听端口通信”,这样才能建立起一个临时的通信隧道。
代码如下
安装在中介服务器上的程序
package main
import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
)
/**
本程序为用户可以直接访问的中介服务器上运行的穿透服务端
工作流程:
1启动 ControlPort 的监听,等待在内网环境运行的受控端连接
(内网主机要连接到这个端口等待控制指令和心跳)
2启动 ListenPort 的监听,等待用户的web请求
3启动 TunnelPort 的监听,等待内网受控端链接隧道
4启动定时释放任务,清楚过期链接和无效链接
5创建一个阻塞等待配对器组装
用户访问时:
1. ListenPort 接受到用户的请求,
通过addConnMatchAccept新建(未完全配置,tunnel为空)一个配对器ConnMatch到全局map connListMap中,
2.通过ControlPort向内网被控主机发送一个链接通知“new\n”
3.-内网主机在接收到消息后,打通真实web服务器和TunnelPort的隧道
4.监听在TunnelPort的程序得到一个新的请求链接
5.在makeForward方法中通过configConnListTunnel方法将1步骤创建的不完整的配对器补充完整
6.通过向通道connListMapUpdate中传值通信,运行tcpForward中joinConn方法,将来用户的链接和来自内网客户端的链接绑定
从而实现内网穿透
*/
func main() {
//监听控制端口8009
go makeControl()
//监听服务端口8007
go makeAccept()
//监听转发端口8008
go makeForward()
//定时释放连接
go releaseConnMatch()
//执行tcp转发
tcpForward()
}
const (
//与安装在内网服务器的client通信长链接接口
ControlPort = ":8009"
//链路实际通信端口
TunnelPort = ":8008"
//用户通信监听端口,即用户严重的web服务器地址
ListenPort = ":8007"
)
var cache *net.TCPConn = nil
func makeControl() {
var tcpAddr *net.TCPAddr
tcpAddr, _ = net.ResolveTCPAddr("tcp", ControlPort)
//打开一个tcp断点监听
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
panic(err)
}
fmt.Println("控制端口已经监听")
for {
tcpConn, err := tcpListener.AcceptTCP()
if err != nil {
panic(err)
}
fmt.Println("新的客户端连接到控制端服务进程:" + tcpConn.RemoteAddr().String())
if cache != nil {
fmt.Println("已经存在一个客户端连接!")
//直接关闭掉多余的客户端请求
tcpConn.Close()
} else {
cache = tcpConn
}
go control(tcpConn)
}
}
func control(conn *net.TCPConn) {
go func() {
for {
//一旦有客户端连接到服务端的话,服务端每隔2秒发送hi消息给到客户端
//如果发送不出去,则认为链路断了,清除cache连接
_, e := conn.Write(([]byte)("hi\n"))
if e != nil {
cache = nil
}
time.Sleep(time.Second * 2)
}
}()
}
func makeAccept() {
var tcpAddr *net.TCPAddr
tcpAddr, _ = net.ResolveTCPAddr("tcp", ListenPort)
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
panic(err)
}
defer tcpListener.Close()
for {
tcpConn, err := tcpListener.AcceptTCP()
if err != nil {
fmt.Println(err)
continue
}
fmt.Println("A client connected 8007:" + tcpConn.RemoteAddr().String())
addConnMatchAccept(tcpConn)
sendMessage("new\n")
}
}
//配对器
type ConnMatch struct {
accept *net.TCPConn //8007 tcp链路 accept
acceptAddTime int64 //接受请求的时间
tunnel *net.TCPConn //8008 tcp链路 tunnel
}
var connListMap = make(map[string]*ConnMatch)
var lock = sync.Mutex{}
func addConnMatchAccept(accept *net.TCPConn) {
//加锁防止竞争读写map
lock.Lock()
defer lock.Unlock()
now := time.Now().UnixNano()
connListMap[strconv.FormatInt(now, 10)] = &ConnMatch{accept, time.Now().Unix(), nil}
}
func sendMessage(message string) {
fmt.Println("send Message " + message)
if cache != nil {
_, e := cache.Write([]byte(message))
if e != nil {
fmt.Println("消息发送异常")
fmt.Println(e.Error())
}
} else {
fmt.Println("没有客户端连接,无法发送消息")
}
}
func makeForward() {
var tcpAddr *net.TCPAddr
tcpAddr, _ = net.ResolveTCPAddr("tcp", TunnelPort)
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
panic(err)
}
defer tcpListener.Close()
fmt.Println("Server ready to read ...")
for {
tcpConn, err := tcpListener.AcceptTCP()
if err != nil {
fmt.Println(err)
continue
}
fmt.Println("A client connected 8008 :" + tcpConn.RemoteAddr().String())
configConnListTunnel(tcpConn)
}
}
var connListMapUpdate = make(chan int)
func configConnListTunnel(tunnel *net.TCPConn) {
//加锁解决竞争问题//todo
lock.Lock()
used := false
for _, connMatch := range connListMap {
//找到tunnel为nil的而且accept不为nil的connMatch
if connMatch.tunnel == nil && connMatch.accept != nil {
//填充tunnel链路
connMatch.tunnel = tunnel
used = true
//这里要break,是防止这条链路被赋值到多个connMatch!
break
}
}
if !used {
//如果没有被使用的话,则说明所有的connMatch都已经配对好了,直接关闭多余的8008链路
fmt.Println(len(connListMap))
_ = tunnel.Close()
fmt.Println("关闭多余的tunnel")
}
lock.Unlock()
//使用channel机制来告诉另一个方法已经就绪
connListMapUpdate <- 0
}
func tcpForward() {
for {
select {
case <-connListMapUpdate:
lock.Lock()
for key, connMatch := range connListMap {
//如果两个都不为空的话,建立隧道连接
if connMatch.tunnel != nil && connMatch.accept != nil {
fmt.Println("建立tcpForward隧道连接")
go joinConn(connMatch.accept, connMatch.tunnel)
//从map中删除
delete(connListMap, key)
}
}
lock.Unlock()
}
}
}
func joinConn(conn1 *net.TCPConn, conn2 *net.TCPConn) {
f := func(local *net.TCPConn, remote *net.TCPConn) {
//defer保证close
defer local.Close()
defer remote.Close()
//使用io.Copy传输两个tcp连接,
_, err := io.Copy(local, remote)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println("join Conn2 end")
}
go f(conn2, conn1)
go f(conn1, conn2)
}
func releaseConnMatch() {
for {
lock.Lock()
for key, connMatch := range connListMap {
//如果在指定时间内没有tunnel的话,则释放该连接
if connMatch.tunnel == nil && connMatch.accept != nil {
if time.Now().Unix()-connMatch.acceptAddTime > 5 {
fmt.Println("释放超时连接")
err := connMatch.accept.Close()
if err != nil {
fmt.Println("释放连接的时候出错了:" + err.Error())
}
delete(connListMap, key)
}
}
}
lock.Unlock()
time.Sleep(5 * time.Second)
}
}
安装在内网服务器或者dmz主机的程序如下:
package main
import (
"bufio"
"fmt"
"io"
"net"
)
/**
本程序为内网环境中web服务器所在的主机(或者可连接到内网web服务器同时可以访问外网的间机器)
工作流程:
1.连接远端服务器ControlAddrPort,接受远端服务器的控制命令
2.当用户访问远端服务器时,ControlAddrPort传来控制命令"new\n",
3.执行combine方法,程序同时拨通TunnelAddrPort 和 ServerAddrPort,
4.并通过joinConn方法,用io.Copy的方式讲TunnelAddrPort的通信数据和ServerAddrPort的通信数据配对,
实现将内网数据提交到中介服务器ControlAddrPort,再通过中介服务期上的程序实现内网穿透
*/
func main() {
connectControl()
}
const (
//中介服务器控制端程序连接地址和端口
//我在内网测试所以填写内网地址,不要被误导
ControlAddrPort = "192.168.3.99:8009"
//链路实际通信连接地址和端口
//我在内网测试所以填写内网地址,不要被误导
TunnelAddrPort = "192.168.3.99:8008"
//内网服务程序地址和端口
ServerAddrPort = "127.0.0.1:80"
)
//连接到服务器的8009控制端口,随时接受服务器的控制请求,随时待命
func connectControl() {
var tcpAddr *net.TCPAddr
//这里在一台机测试,所以没有连接到公网,可以修改到公网ip
tcpAddr, _ = net.ResolveTCPAddr("tcp", ControlAddrPort)
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
fmt.Println("Client connect error ! " + err.Error())
return
}
fmt.Println(conn.LocalAddr().String() + " : Client connected!8009")
reader := bufio.NewReader(conn)
for {
s, err := reader.ReadString('\n')
if err != nil || err == io.EOF {
break
} else {
//接收到new的指令的时候,新建一个tcp连接
if s == "new\n" {
go combine()
}
if s == "hi" {
//忽略掉hi的请求
}
}
}
}
//combine方法的代码,整合local和remote的tcp连接
func combine() {
local := connectLocal()
remote := connectRemote()
if local != nil && remote != nil {
joinConn(local, remote)
} else {
if local != nil {
err := local.Close()
if err!=nil{
fmt.Println("close local:" + err.Error())
}
}
if remote != nil {
err := remote.Close()
if err!=nil{
fmt.Println("close remote:" + err.Error())
}
}
}
}
func joinConn(local *net.TCPConn, remote *net.TCPConn) {
f := func(local *net.TCPConn, remote *net.TCPConn) {
defer local.Close()
defer remote.Close()
_, err := io.Copy(local, remote)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println("end")
}
go f(local, remote)
go f(remote, local)
}
//connectLocal 连接到内网web服务器!
func connectLocal() *net.TCPConn {
var tcpAddr *net.TCPAddr
tcpAddr, _ = net.ResolveTCPAddr("tcp", ServerAddrPort)
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
fmt.Println("Client connect error ! " + err.Error())
return nil
}
fmt.Println(conn.LocalAddr().String() + " : Client connected!8000")
return conn
}
//connectRemote 连接到服务端的8008端口!
func connectRemote() *net.TCPConn {
var tcpAddr *net.TCPAddr
tcpAddr, _ = net.ResolveTCPAddr("tcp", TunnelAddrPort)
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
fmt.Println("Client connect error ! " + err.Error())
return nil
}
fmt.Println(conn.LocalAddr().String() + " : Client connected!8008")
return conn;
}