Golang TCP 转发

前端之家收集整理的这篇文章主要介绍了 Golang TCP 转发，前端之家小编觉得挺不错的，现在分享给大家，也给大家做个参考。
package main

import (
  "os"
  "io"
  "fmt"
  "net"
  "strings"
  "strconv"
  "syscall"
  "encoding/binary"
)

// CSPair ties an accepted client connection to the connection dialed to the
// destination that client was originally trying to reach, together with the
// addresses of both peers.
type CSPair struct {
	clientaddr, serveraddr net.Addr
	clientconn, serverconn *net.TCPConn
}

// SO_ORIGINAL_DST is the getsockopt option number (level IPPROTO_IP) used to
// ask the kernel for a redirected connection's original destination address
// (value from <linux/netfilter_ipv4.h>); the syscall package does not export it.
const SO_ORIGINAL_DST = 80

// connection_count is the number of client connections accepted so far. In
// this file it is only incremented by construct_connection, which main calls
// synchronously from its accept loop, so it is not guarded by a lock.
var connection_count int

// main runs the forwarder. It listens for IPv4 TCP connections on port 8838
// and, for each accepted connection, builds the client/server pair with
// construct_connection and starts relaying with handle_data. Any listen or
// accept error terminates the process through handle_error.
func main() {
	const port = 8838
	ln, err := net.ListenTCP("tcp4", &net.TCPAddr{Port: port})
	handle_error(err)
	defer ln.Close()
	fmt.Printf("listen on %d\n", port)

	for {
		client, err := ln.AcceptTCP()
		handle_error(err)
		handle_data(construct_connection(client))
	}
}

// handle_data starts the two relay goroutines for pair: handle_cs forwards
// client -> server and handle_sc forwards server -> client. It returns
// immediately without waiting for either direction to finish.
func handle_data(pair *CSPair) {
	for _, relay := range []func(*CSPair){handle_cs, handle_sc} {
		go relay(pair)
	}
}

// handle_cs relays data from the client to the server for one connection pair.
//
// Traffic whose destination port is 843 is copied through unchanged. All other
// traffic is treated as a stream of frames, each a 4-byte little-endian length
// followed by that many payload bytes. Every complete frame (header included)
// is logged and sent to the server in a single Write. Bytes that do not yet
// form a complete frame stay buffered until more client data arrives.
//
// When the client stream ends (EOF or a read error) or a write to the server
// fails, the function logs any unexpected error, half-closes the server
// connection so the server sees EOF, and closes the client connection. It
// never terminates the process.
func handle_cs(pair *CSPair) {
	defer pair.clientconn.Close()
	// Shut down our write half toward the server when this direction ends;
	// otherwise the server (and handle_sc, reading from it) could wait forever.
	defer pair.serverconn.CloseWrite()

	// Match ":843" as a suffix so that ports such as 8430 are not caught.
	if strings.HasSuffix(pair.serveraddr.String(), ":843") {
		fmt.Println(":843 connection.")
		if _, err := io.Copy(pair.serverconn, pair.clientconn); err != nil {
			fmt.Println("handle_cs:", err)
		}
		return
	}

	var remain_data []byte
	for {
		bs, err := readPacket(pair.clientconn)
		// Buffer whatever was read before looking at err, in case readPacket
		// returns data together with an error.
		remain_data = append(remain_data, bs...)

		// A single read may complete zero, one or several frames: forward
		// every complete frame now and keep the remainder buffered. The
		// length check also prevents Uint32 from panicking on a short buffer.
		for len(remain_data) >= 4 {
			// uint64 keeps "+4" from overflowing where int is 32 bits wide.
			packet_len := uint64(binary.LittleEndian.Uint32(remain_data)) + 4 // include the length header
			fmt.Printf("remain_data: 0x%x,packet_len: 0x%x\n", len(remain_data), packet_len)
			if packet_len > uint64(len(remain_data)) {
				break // wait for the rest of this frame
			}
			packet_data := remain_data[:packet_len]
			remain_data = remain_data[packet_len:]
			fmt.Printf("receive 0x%x:%s\n", packet_len, string(packet_data))
			n, werr := pair.serverconn.Write(packet_data)
			if werr != nil {
				fmt.Println("handle_cs write:", werr)
				return
			}
			fmt.Printf("handle_cs write 0x%x bytes\n", n)
		}

		if err != nil {
			// EOF is the normal end of the client stream; only report others.
			if err != io.EOF {
				fmt.Println("handle_cs read:", err)
			}
			return
		}
	}
}

// handle_sc relays data from the server back to the client, unchanged, until
// the server closes its side or an I/O error occurs. It then logs any copy
// error, half-closes the client connection so the client sees EOF, and closes
// the server connection.
func handle_sc(pair *CSPair) {
	defer pair.serverconn.Close()
	if _, err := io.Copy(pair.clientconn, pair.serverconn); err != nil {
		fmt.Println("handle_sc:", err)
	}
	// Signal end-of-stream to the client; without it the client may keep the
	// connection open and handle_cs may stay blocked reading from it.
	pair.clientconn.CloseWrite()
	fmt.Println("handle_sc close pair.serverconn")
}

// construct_connection builds the CSPair for an accepted client connection c.
//
// It reads the destination the client originally addressed through the
// SO_ORIGINAL_DST socket option, dials that IPv4 address, increments
// connection_count, logs the new pair and returns it. Any failure terminates
// the process through handle_error.
func construct_connection(c *net.TCPConn) *CSPair {
	pair := &CSPair{clientconn: c, clientaddr: c.RemoteAddr()}

	// c.File returns a duplicate of the socket's descriptor that has to be
	// closed separately; leaving it open leaked one fd per connection.
	f, err := c.File()
	handle_error(err)
	// The option's result is read through the 16-byte Multiaddr field of an
	// IPv6Mreq: bytes 2-3 hold the port in network byte order and bytes 4-7
	// the IPv4 address.
	mreq, err := syscall.GetsockoptIPv6Mreq(int(f.Fd()), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
	f.Close()
	handle_error(err)

	raw := mreq.Multiaddr
	ip := net.IPv4(raw[4], raw[5], raw[6], raw[7])
	port := int(binary.BigEndian.Uint16(raw[2:4]))

	sa, err := net.ResolveTCPAddr("tcp4", net.JoinHostPort(ip.String(), strconv.Itoa(port)))
	handle_error(err)
	pair.serveraddr = sa
	pair.serverconn, err = net.DialTCP("tcp4", nil, sa)
	handle_error(err)

	connection_count++
	fmt.Printf("accept %d,%s and create a new connection to server %s(%s:%d)\n", connection_count, pair.clientaddr.String(), pair.serveraddr.String(), ip.String(), port)
	return pair
}

// handle_error treats any non-nil err as fatal: it prints err to standard
// output and exits the process with status 1. A nil err is a no-op.
func handle_error(err error) {
	if err == nil {
		return
	}
	fmt.Println(err)
	os.Exit(1)
}
原文链接:https://www.f2er.com/go/189352.html

猜你在找的Go相关文章