blob: a7900bb13ad5d2b0e69144150e91e55169047939 [file] [log] [blame]
// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package rpctype
import (
"compress/flate"
"fmt"
"io"
"net"
"net/rpc"
"os"
"time"
"github.com/google/syzkaller/pkg/log"
)
// RPCServer is a net/rpc server bound to a TCP listener.
type RPCServer struct {
	ln net.Listener // listening socket that incoming connections are accepted from
	s  *rpc.Server  // rpc server the receiver is registered with
}

// NewRPCServer starts listening on the TCP address addr and registers receiver
// with a fresh rpc server under the service name name.
//
// It returns an error if the address cannot be listened on or if receiver is
// rejected by rpc.Server.RegisterName. In the latter case the listener is
// closed again, so no socket is leaked on failure.
func NewRPCServer(addr, name string, receiver interface{}) (*RPCServer, error) {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on %v: %v", addr, err)
	}
	s := rpc.NewServer()
	if err := s.RegisterName(name, receiver); err != nil {
		// The server is never handed to the caller, so nobody else could
		// close the listener; release the port here.
		ln.Close()
		return nil, err
	}
	return &RPCServer{ln: ln, s: s}, nil
}
// Serve accepts connections on the server's listener in an endless loop and
// serves each of them in its own goroutine over a flate-compressed stream.
// It never returns.
//
// Accept failures are logged and retried. Consecutive failures are separated
// by a growing delay (5ms doubling up to 1s) so that a persistent error, such
// as running out of file descriptors, does not turn into a busy loop that
// floods the log.
func (serv *RPCServer) Serve() {
	const (
		minDelay = 5 * time.Millisecond
		maxDelay = time.Second
	)
	var delay time.Duration
	for {
		conn, err := serv.ln.Accept()
		if err != nil {
			log.Logf(0, "failed to accept an rpc connection: %v", err)
			if delay == 0 {
				delay = minDelay
			} else {
				delay *= 2
			}
			if delay > maxDelay {
				delay = maxDelay
			}
			time.Sleep(delay)
			continue
		}
		// A successful accept resets the backoff.
		delay = 0
		setupKeepAlive(conn, 10*time.Second)
		go serv.s.ServeConn(newFlateConn(conn))
	}
}
// Addr reports the network address of the server's listener.
func (serv *RPCServer) Addr() net.Addr {
	listener := serv.ln
	return listener.Addr()
}
// RPCClient pairs a connection with the net/rpc client that talks over it.
type RPCClient struct {
// conn is the underlying connection; Call sets and clears deadlines on it.
conn net.Conn
// c is the net/rpc client; NewRPCClient creates it on top of a
// flate-compressed wrapper around conn.
c *rpc.Client
}
// Dial opens a connection to addr.
//
// The special address "stdin" does not dial anything: it turns os.Stdin into
// a net.Conn instead. Any other address is dialed over TCP with a 60 second
// timeout, and keep-alive with a one minute period is set up on the result.
func Dial(addr string) (net.Conn, error) {
	if addr == "stdin" {
		// vm/gvisor hands us a unix socket connection as our stdin.
		return net.FileConn(os.Stdin)
	}
	tcpConn, dialErr := net.DialTimeout("tcp", addr, 60*time.Second)
	if dialErr != nil {
		return nil, dialErr
	}
	setupKeepAlive(tcpConn, time.Minute)
	return tcpConn, nil
}
// NewRPCClient connects to addr via Dial and returns a client that speaks
// net/rpc over a flate-compressed wrapper around that connection.
// The Dial error, if any, is returned unchanged.
func NewRPCClient(addr string) (*RPCClient, error) {
	conn, err := Dial(addr)
	if err != nil {
		return nil, err
	}
	client := new(RPCClient)
	client.conn = conn
	client.c = rpc.NewClient(newFlateConn(conn))
	return client, nil
}
// Call invokes the named rpc method with args, stores the result in reply and
// returns the error of the rpc call. The connection gets a 5 minute deadline
// for the duration of the call; the deadline is cleared again afterwards.
func (cli *RPCClient) Call(method string, args, reply interface{}) error {
	const callTimeout = 5 * time.Minute
	// Note: SetDeadline is not implemented on fuchsia, so its error is
	// deliberately ignored both here and in the deferred reset.
	cli.conn.SetDeadline(time.Now().Add(callTimeout))
	defer cli.conn.SetDeadline(time.Time{})
	return cli.c.Call(method, args, reply)
}
// Close closes the underlying rpc client. The error returned by the rpc
// client's Close is discarded.
func (cli *RPCClient) Close() {
	_ = cli.c.Close()
}
// RPCCall performs a single rpc call: it creates a client for addr, invokes
// method with args and reply, and closes the client again before returning.
// It returns the connection error if the client cannot be created, otherwise
// the error of the call itself.
func RPCCall(addr, method string, args, reply interface{}) error {
	client, connErr := NewRPCClient(addr)
	if connErr != nil {
		return connErr
	}
	defer client.Close()
	return client.Call(method, args, reply)
}
// setupKeepAlive enables TCP keep-alive on conn with the given period.
//
// It is best effort: errors from the socket options are ignored, and
// connections that are not *net.TCPConn are left untouched instead of
// triggering a type-assertion panic.
func setupKeepAlive(conn net.Conn, keepAlive time.Duration) {
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		return
	}
	tcpConn.SetKeepAlive(true)
	tcpConn.SetKeepAlivePeriod(keepAlive)
}
// flateConn wraps a connection in flate.Reader/Writer for compressed traffic.
type flateConn struct {
	r io.ReadCloser // decompressing reader on top of the connection
	w *flate.Writer // compressing writer on top of the connection
	c io.Closer     // the wrapped connection itself, kept for Close
}

// newFlateConn returns a stream that compresses everything written to conn
// (at the best compression level) and decompresses everything read from it.
// Closing the result also closes conn.
func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser {
	w, err := flate.NewWriter(conn, flate.BestCompression)
	if err != nil {
		// The level is a constant, so a failure here is a programming error.
		panic(err)
	}
	return &flateConn{
		r: flate.NewReader(conn),
		w: w,
		c: conn,
	}
}

// Read reads decompressed data into data.
func (fc *flateConn) Read(data []byte) (int, error) {
	return fc.r.Read(data)
}

// Write compresses data and flushes the compressor right away, so the bytes
// are passed on to the connection instead of sitting in the compressor until
// a later write. It returns the number of bytes of data accepted by the
// compressor.
func (fc *flateConn) Write(data []byte) (int, error) {
	n, err := fc.w.Write(data)
	if err != nil {
		return n, err
	}
	return n, fc.w.Flush()
}

// Close closes the reader, the writer and the wrapped connection, in that
// order. All three are always closed; the first error encountered is
// returned so that a later failure does not mask the original one.
func (fc *flateConn) Close() error {
	var firstErr error
	for _, closer := range []io.Closer{fc.r, fc.w, fc.c} {
		if err := closer.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}