Golang 1.7 中的 TCP 交互封装

本文由前端之家收集整理，主要介绍 Golang 1.7 中对 TCP 交互的封装：每条消息前加上 4 字节的大端长度头，读和写分别用各自的互斥锁保护。小编觉得挺不错的，现在分享给大家，也给大家做个参考。
package connection

import (
    "encoding/binary"
    "errors"
    "io"
    "net"
    "sync"
    "time"
)

// maxLength is the largest message length accepted by whead.
// NOTE(review): the value does not fit in a 32-bit int, so this declaration
// presumably only builds on platforms where int is 64 bits wide — confirm.
const maxLength int = 1<<32 - 1 // 4294967295

var (
    // rHeadBytes is the 4-byte buffer rhead reads the length header into.
    // NOTE(review): it is package-level, while the read lock is per
    // connection, so two connections reading at once share this buffer.
    rHeadBytes = [4]byte{0,0,0}
    // wHeadBytes is the 4-byte buffer whead encodes the length header into.
    // NOTE(review): shared between connections in the same way as rHeadBytes.
    wHeadBytes = [4]byte{0,0}
    // errMsgRead is not referenced by the code visible in this file.
    errMsgRead = errors.New("Message read length error")
    // errHeadLen is returned by rhead when the header read yields a byte
    // count other than 4 without an error.
    errHeadLen = errors.New("Message head length error")
    // errMsgLen is returned by whead when the length is <= 0 or > maxLength.
    errMsgLen  = errors.New("Message length is no longer in normal range")
)

// connPool holds *connection values handed back by Close so that
// Newconnection can reuse them. No New function is set on it.
var connPool sync.Pool

// Newconnection wraps conn in a length-prefixed Conn.
//
// It first tries to take a previously closed *connection out of connPool and
// rebinds it to conn; when the pool has nothing usable it allocates a new one.
func Newconnection(conn net.Conn) Conn {
    pooled, reusable := connPool.Get().(*connection)
    if !reusable {
        // Empty pool (Get returned nil) or an unexpected value: start fresh.
        return &connection{rwc: conn}
    }
    pooled.rwc = conn
    return pooled
}

// Conn is a message-oriented wrapper around a net.Conn. In the implementation
// in this file every message is preceded by a 4-byte big-endian length header.
type Conn interface {
    // Read reads the next message header and returns a reader for the
    // message body together with the body size announced by the header.
    Read() (r io.Reader,size int,err error)
    // Write sends p as one message: a length header followed by p.
    Write(p []byte) (n int,err error)
    // Writer sends a message of the given size whose body is copied from r.
    Writer(size int,r io.Reader) (n int64,err error)
    // RemoteAddr returns the remote address of the underlying connection.
    RemoteAddr() net.Addr
    // LocalAddr returns the local address of the underlying connection.
    LocalAddr() net.Addr
    // SetDeadline forwards to the underlying connection's SetDeadline.
    SetDeadline(t time.Time) error
    // SetReadDeadline forwards to the underlying connection's SetReadDeadline.
    SetReadDeadline(t time.Time) error
    // SetWriteDeadline forwards to the underlying connection's SetWriteDeadline.
    SetWriteDeadline(t time.Time) error
    // Close closes the underlying connection.
    Close() (err error)
}


// connection is the Conn implementation: it wraps a net.Conn and keeps
// separate mutexes for the read side and the write side.
type connection struct {
    rlen  int         // body length taken from the most recently read header
    rwc   net.Conn    // the underlying network connection
    rlock sync.Mutex  // read lock: taken by Read, released through the returned reader
    wlock sync.Mutex  // write lock: held for the duration of Write / Writer
}

// rhead reads the 4-byte big-endian length header of the next message and
// stores the decoded length in rlen.
//
// The header is read with io.ReadFull because a single Read on a stream
// connection may legitimately return fewer than 4 bytes; treating that as an
// error would desynchronise the stream. The buffer is local to the call so
// that different connections never share header bytes.
//
// It returns errHeadLen when the stream ends in the middle of a header, and
// the underlying read error (including io.EOF when no header byte was read)
// otherwise.
func (c *connection) rhead() error {
    var head [4]byte
    if _, err := io.ReadFull(c.rwc, head[:]); err != nil {
        if err == io.ErrUnexpectedEOF {
            // Some, but not all, header bytes arrived before the stream ended.
            return errHeadLen
        }
        return err
    }
    c.rlen = int(binary.BigEndian.Uint32(head[:]))
    return nil
}
// whead validates the body length l and writes it to the connection as a
// 4-byte big-endian header.
//
// It returns errMsgLen when l is not in the range 1..maxLength, otherwise the
// error (if any) from writing the header. The header is encoded into a buffer
// local to the call so that different connections never share header bytes.
func (c *connection) whead(l int) error {
    if l <= 0 || l > maxLength {
        return errMsgLen
    }
    var head [4]byte
    binary.BigEndian.PutUint32(head[:], uint32(l))
    _, err := c.rwc.Write(head[:])
    return err
}

// Read takes the read lock, parses the next message header and returns a
// reader for the message body together with the body size in bytes.
//
// The returned reader is limited to size bytes. The read lock stays held
// until that reader reports an error (normally io.EOF once the body has been
// consumed), so the caller must drain it before the next Read can proceed.
// If the header cannot be read, the lock is released and the error returned.
func (c *connection) Read() (r io.Reader, size int, err error) {
    c.rlock.Lock()
    if err = c.rhead(); err != nil {
        c.rlock.Unlock()
        return nil, 0, err
    }
    size = c.rlen

    // The body reader invokes unlock on every failing Read. Guard it so that
    // reading again after EOF cannot unlock an already unlocked mutex.
    var once sync.Once
    release := func() { once.Do(c.rlock.Unlock) }

    r = limitRead{r: io.LimitReader(c.rwc, int64(size)), unlock: release}
    return r, size, nil
}

// Write sends p as a single message. While holding the write lock it first
// emits the length header via whead and then the body itself.
//
// It returns the number of body bytes written. If the header cannot be sent
// (for example because len(p) is rejected by whead), nothing of the body is
// written and n is 0.
func (c *connection) Write(p []byte) (n int, err error) {
    c.wlock.Lock()
    defer c.wlock.Unlock()

    if err = c.whead(len(p)); err != nil {
        return 0, err
    }
    return c.rwc.Write(p)
}

// Writer sends a stream as a single message. The caller must state the body
// length up front: size is written as the header and exactly size bytes are
// then copied from r to the connection.
//
// It returns the number of body bytes copied. If the header cannot be sent,
// nothing is copied and n is 0. The write lock is held for the whole call.
func (c *connection) Writer(size int, r io.Reader) (n int64, err error) {
    c.wlock.Lock()
    defer c.wlock.Unlock()

    if err = c.whead(size); err != nil {
        return 0, err
    }
    return io.CopyN(c.rwc, r, int64(size))
}

// RemoteAddr returns the remote address reported by the wrapped net.Conn.
func (self *connection) RemoteAddr() net.Addr {
    return self.rwc.RemoteAddr()
}

// LocalAddr returns the local address reported by the wrapped net.Conn.
func (self *connection) LocalAddr() net.Addr {
    return self.rwc.LocalAddr()
}

// SetDeadline forwards t to the wrapped net.Conn's SetDeadline and returns
// its result.
func (self *connection) SetDeadline(t time.Time) error {
    return self.rwc.SetDeadline(t)
}

// SetReadDeadline forwards t to the wrapped net.Conn's SetReadDeadline and
// returns its result.
func (self *connection) SetReadDeadline(t time.Time) error {
    return self.rwc.SetReadDeadline(t)
}

// SetWriteDeadline forwards t to the wrapped net.Conn's SetWriteDeadline and
// returns its result.
func (self *connection) SetWriteDeadline(t time.Time) error {
    return self.rwc.SetWriteDeadline(t)
}

// Close closes the wrapped net.Conn and returns its error. Afterwards the
// stored message length is reset and the connection object is handed to
// connPool so that Newconnection can reuse it.
//
// The object is put into the pool unconditionally, so calling Close more than
// once on the same value places it in the pool more than once.
func (c *connection) Close() (err error) {
    defer func() {
        c.rlen = 0
        connPool.Put(c)
    }()
    return c.rwc.Close()
}

// limitRead is an io.Reader that forwards to r and calls unlock whenever a
// read from r returns a non-nil error (including io.EOF).
type limitRead struct {
    r      io.Reader // source of the data
    unlock func()    // invoked on every read that fails
}

// Read reads from the wrapped reader into p and returns its result unchanged.
// When the wrapped reader reports an error, unlock is called before returning.
func (lr limitRead) Read(p []byte) (int, error) {
    n, err := lr.r.Read(p)
    if err == nil {
        return n, nil
    }
    lr.unlock()
    return n, err
}

测试代码如下：

package connection

import (
    "fmt"
    "io"
    "net"
    "os"
    "testing"
    "time"
)

// Test_conn starts Listener on TCP port 1789 in the background, waits one
// second, and then runs the client side via Dial.
func Test_conn(t *testing.T) {
    go Listener("tcp", ":1789")

    // Pause before dialing so the background goroutine has time to run.
    time.Sleep(time.Second)

    Dial()
}

// Dial is the client side of the test: it connects to 127.0.0.1:1789, sends
// the message "Test" twice, then reads one reply message and copies its body
// to standard output. Any error is printed and ends the function.
func Dial() {
    conn, err := net.Dial("tcp", "127.0.0.1:1789")
    if err != nil {
        fmt.Println(err)
        return
    }
    c := Newconnection(conn)
    defer c.Close()

    for i := 0; i < 2; i++ {
        // A failed write leaves the stream in an unknown state; stop here
        // instead of waiting for a reply that may never come.
        if _, err := c.Write([]byte("Test")); err != nil {
            fmt.Println(err)
            return
        }
    }

    r, size, err := c.Read()
    if err != nil {
        fmt.Println(err, size)
        return
    }
    // io.Copy reports a nil error when the source ends with io.EOF, so any
    // non-nil error here is a real failure.
    if _, err = io.Copy(os.Stdout, r); err != nil {
        fmt.Println(err)
    }
}

// Listener listens on the given network (proto) and address and serves every
// accepted connection with handler in its own goroutine.
//
// It panics if the listener cannot be created. A failed Accept is retried
// after a 10ms pause; the accept loop never returns on its own.
func Listener(proto, addr string) {
    lis, err := net.Listen(proto, addr)
    if err != nil {
        panic("Listen port error:" + err.Error())
    }
    defer lis.Close()

    for {
        conn, err := lis.Accept()
        if err != nil {
            time.Sleep(10 * time.Millisecond)
            continue
        }
        go handler(conn)
    }
}

// handler serves one accepted connection. It prints the body of every
// incoming message to standard output. Two seconds after the first complete
// message it triggers a single reply, sent from a separate goroutine: the
// contents of the file "tcp_test.go" as one message.
//
// The connection is closed exactly once, when the read loop ends (read error
// or peer disconnect) and the reply goroutine has finished.
func handler(conn net.Conn) {
    c := Newconnection(conn)
    msgchan := make(chan struct{})
    done := make(chan struct{})

    // Deferred calls run last-in-first-out: release the reply goroutine if it
    // is still waiting, wait for it to finish, and only then close the
    // connection, so nothing uses c after it has been closed.
    defer c.Close()
    defer func() { <-done }()
    defer close(msgchan)

    go func(ch <-chan struct{}) {
        defer close(done)
        if _, ok := <-ch; !ok {
            // The handler ended before any message arrived; nothing to send.
            return
        }
        f, err := os.Open("tcp_test.go")
        if err != nil {
            fmt.Println(err)
            return
        }
        defer f.Close()
        info, err := f.Stat()
        if err != nil {
            fmt.Println(err)
            return
        }
        if _, err := c.Writer(int(info.Size()), f); err != nil {
            fmt.Println(err)
        }
    }(msgchan)

    replied := false
    for {
        r, size, err := c.Read()
        if err != nil {
            fmt.Println(err)
            return
        }
        n, err := io.Copy(os.Stdout, r)
        if err != nil || n != int64(size) {
            fmt.Println("读取数据失败:", err)
            return
        }
        if replied {
            // The reply goroutine handles one signal only; a second send
            // would block forever.
            continue
        }
        time.Sleep(2e9)
        msgchan <- struct{}{}
        replied = true
    }
}

猜你在找的Go相关文章