一个基于 Redis 的 RPC 服务

设计思路

通信机制:
  • 使用 Redis 的 List 作为请求队列
  • 响应以带过期时间(10 分钟)的 String 键保存,每个请求对应一个响应键
  • 使用 Pub/Sub 通知客户端响应已就绪
工作流程:
  • 客户端先订阅本次请求专属的通知频道,再将请求序列化后放入请求队列
  • 服务端从请求队列中取出请求并处理,将响应写入对应的响应键,然后在通知频道上发布消息
  • 客户端收到通知后读取响应键获取结果

实现代码

  1. RPC服务端
package redisrpc

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "log"
    "time"

    "github.com/go-redis/redis/v8"
)

// Server represents an RPC server for a single service. It pops requests from
// the Redis list "rpc:<service>:requests" and writes each reply back to Redis.
type Server struct {
    // redisClient is used for BRPop on the request list, Set on response keys
    // and Publish on notification channels.
    redisClient *redis.Client
    // service is interpolated into every key and channel name ("rpc:<service>:...").
    service     string
    // handlers maps a method name to its handler. It has no lock and is read
    // by the loop started in Start, so register all handlers before Start.
    handlers    map[string]func(params json.RawMessage) (interface{}, error)
    // stopChan is closed by Stop to make the Start loop exit.
    stopChan    chan struct{}
}

// NewServer returns a Server for the named service that talks to Redis through
// redisClient. The server starts with no handlers and does nothing until Start
// is called; register handlers with RegisterHandler before that.
func NewServer(redisClient *redis.Client, service string) *Server {
    s := new(Server)
    s.redisClient = redisClient
    s.service = service
    s.handlers = map[string]func(params json.RawMessage) (interface{}, error){}
    s.stopChan = make(chan struct{})
    return s
}

// RegisterHandler registers handler as the implementation of method,
// replacing any handler previously registered under the same name.
//
// The handler map is not synchronized, so register all handlers before
// calling Start. A nil handler panics here, at registration time, rather than
// later when the first matching request invokes it.
func (s *Server) RegisterHandler(method string, handler func(params json.RawMessage) (interface{}, error)) {
    if handler == nil {
        panic("redisrpc: nil handler for method " + method)
    }
    s.handlers[method] = handler
}

// Start launches the request loop in a background goroutine and returns
// immediately. The returned error is currently always nil.
//
// The loop pops JSON-encoded Request values from the Redis list
// "rpc:<service>:requests" and runs the matching handler in its own goroutine.
// Unknown methods, handler errors and handler panics are all reported back to
// the caller as error responses. The loop exits once Stop closes stopChan; it
// checks between BRPop calls, whose timeout is 1s.
func (s *Server) Start() error {
    requestQueue := fmt.Sprintf("rpc:%s:requests", s.service)

    go func() {
        for {
            select {
            case <-s.stopChan:
                return
            default:
            }

            // BRPop with timeout to allow checking stopChan.
            result, err := s.redisClient.BRPop(context.Background(), 1*time.Second, requestQueue).Result()
            if err != nil {
                // redis.Nil only means the timeout elapsed with an empty queue.
                if !errors.Is(err, redis.Nil) {
                    log.Printf("Error reading from queue: %v", err)
                    // Back off so a persistent failure (e.g. Redis unreachable)
                    // does not turn into a busy loop that floods the log.
                    select {
                    case <-s.stopChan:
                        return
                    case <-time.After(time.Second):
                    }
                }
                continue
            }

            // result[0] is the key, result[1] is the value.
            if len(result) < 2 {
                continue
            }

            var req Request
            if err := json.Unmarshal([]byte(result[1]), &req); err != nil {
                log.Printf("Error unmarshaling request: %v", err)
                continue
            }

            handler, exists := s.handlers[req.Method]
            if !exists {
                s.sendErrorResponse(req.ID, fmt.Sprintf("Method %s not found", req.Method))
                continue
            }

            go func(req Request) {
                // A panicking handler must not take down the whole server;
                // report the panic to the caller instead.
                defer func() {
                    if r := recover(); r != nil {
                        log.Printf("Handler for method %s panicked: %v", req.Method, r)
                        s.sendErrorResponse(req.ID, fmt.Sprintf("handler for method %s panicked: %v", req.Method, r))
                    }
                }()

                res, err := handler(req.Params)
                if err != nil {
                    s.sendErrorResponse(req.ID, err.Error())
                    return
                }
                s.sendResponse(req.ID, res)
            }(req)
        }
    }()

    return nil
}

// Stop asks the loop started by Start to exit. The loop checks for this
// between BRPop calls, so it stops within roughly one BRPop timeout. Handler
// goroutines that are already running are not waited for.
//
// Calling Stop again after it has returned is a no-op; closing stopChan a
// second time would panic. Concurrent calls to Stop are not synchronized.
func (s *Server) Stop() {
    select {
    case <-s.stopChan:
        // Already stopped.
    default:
        close(s.stopChan)
    }
}

// sendResponse delivers a successful result for request id to the waiting
// client. If result cannot be JSON-encoded, the client receives an error
// response instead, so it does not wait until its timeout.
func (s *Server) sendResponse(id string, result interface{}) {
    data, err := json.Marshal(Response{ID: id, Result: result, Error: ""})
    if err != nil {
        log.Printf("Error marshaling response: %v", err)
        s.sendErrorResponse(id, fmt.Sprintf("error marshaling result: %v", err))
        return
    }
    s.publishResponse(id, data, "response")
}

// sendErrorResponse delivers errorMsg as the failure reason for request id.
// It publishes the same "ready" notification as sendResponse. Client.Call
// reads the response key only after that notification arrives, so an error
// reply without it would leave the client blocked until its timeout.
func (s *Server) sendErrorResponse(id string, errorMsg string) {
    data, err := json.Marshal(Response{ID: id, Result: nil, Error: errorMsg})
    if err != nil {
        log.Printf("Error marshaling error response: %v", err)
        return
    }
    s.publishResponse(id, data, "error response")
}

// publishResponse stores an encoded Response under
// "rpc:<service>:responses:<id>" with a 10-minute TTL. It then publishes
// "ready" on "rpc:<service>:notifications:<id>" to wake the client. kind is
// used only in log messages.
func (s *Server) publishResponse(id string, data []byte, kind string) {
    ctx := context.Background()

    responseKey := fmt.Sprintf("rpc:%s:responses:%s", s.service, id)
    if err := s.redisClient.Set(ctx, responseKey, data, 10*time.Minute).Err(); err != nil {
        log.Printf("Error sending %s: %v", kind, err)
        return
    }

    notificationChannel := fmt.Sprintf("rpc:%s:notifications:%s", s.service, id)
    if err := s.redisClient.Publish(ctx, notificationChannel, "ready").Err(); err != nil {
        log.Printf("Error publishing notification: %v", err)
    }
}

  2. RPC客户端

package redisrpc

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "time"

    "github.com/go-redis/redis/v8"
    "github.com/google/uuid"
)

// Client represents an RPC client for a single service. Call pushes a request
// onto "rpc:<service>:requests" and waits for the server's reply.
type Client struct {
    // redisClient is used to push requests, subscribe to notifications and
    // read responses.
    redisClient *redis.Client
    // service selects the key/channel namespace "rpc:<service>:...".
    service     string
    // timeout bounds how long Call waits for a reply.
    timeout     time.Duration
}

// NewClient returns a Client that calls methods of the named service through
// redisClient. timeout bounds how long each Call waits for its reply.
func NewClient(redisClient *redis.Client, service string, timeout time.Duration) *Client {
    c := new(Client)
    c.redisClient, c.service, c.timeout = redisClient, service, timeout
    return c
}

// Call invokes method on the remote service and waits up to the client's
// timeout for the reply. It is CallContext with context.Background().
func (c *Client) Call(method string, params interface{}, result interface{}) error {
    return c.CallContext(context.Background(), method, params, result)
}

// CallContext invokes method on the remote service.
//
// params is JSON-encoded as the request's "params". If result is non-nil, the
// JSON result is decoded into it, so it should be a pointer. The returned
// error is the server's error message when the remote call failed. Otherwise
// it wraps the underlying Redis/JSON/context error. The whole call is bounded
// by ctx and by the client's timeout, whichever ends first.
//
// Protocol:
//   1. Subscribe to rpc:<service>:notifications:<id>.
//   2. LPUSH the request onto rpc:<service>:requests.
//   3. Wait for the notification.
//   4. Read, then delete, the reply stored at rpc:<service>:responses:<id>.
func (c *Client) CallContext(ctx context.Context, method string, params interface{}, result interface{}) error {
    ctx, cancel := context.WithTimeout(ctx, c.timeout)
    defer cancel()

    requestID := uuid.New().String()
    requestQueue := fmt.Sprintf("rpc:%s:requests", c.service)
    responseKey := fmt.Sprintf("rpc:%s:responses:%s", c.service, requestID)
    notificationChannel := fmt.Sprintf("rpc:%s:notifications:%s", c.service, requestID)

    paramsData, err := json.Marshal(params)
    if err != nil {
        return fmt.Errorf("error marshaling params: %w", err)
    }
    requestData, err := json.Marshal(Request{ID: requestID, Method: method, Params: paramsData})
    if err != nil {
        return fmt.Errorf("error marshaling request: %w", err)
    }

    // Subscribe BEFORE pushing the request, and wait for Redis to confirm the
    // subscription. go-redis's Subscribe does not block, so without Receive a
    // fast server could publish "ready" before the subscription is active and
    // the notification would be lost.
    pubsub := c.redisClient.Subscribe(ctx, notificationChannel)
    defer pubsub.Close()
    if _, err := pubsub.Receive(ctx); err != nil {
        return fmt.Errorf("error subscribing to %s: %w", notificationChannel, err)
    }

    if err := c.redisClient.LPush(ctx, requestQueue, requestData).Err(); err != nil {
        return fmt.Errorf("error sending request: %w", err)
    }

    select {
    case <-pubsub.Channel():
        // Response is ready.
    case <-ctx.Done():
        if errors.Is(ctx.Err(), context.DeadlineExceeded) {
            return fmt.Errorf("request timed out: %w", ctx.Err())
        }
        return fmt.Errorf("request canceled: %w", ctx.Err())
    }

    responseData, err := c.redisClient.Get(ctx, responseKey).Bytes()
    if err != nil {
        return fmt.Errorf("error getting response: %w", err)
    }
    // The reply is single-use. Remove it now rather than leaving it until its
    // TTL expires. Best effort: a failure here does not affect the result.
    _ = c.redisClient.Del(ctx, responseKey).Err()

    // Keep the result as raw JSON. Round-tripping it through interface{}
    // would turn every number into float64 and lose precision for integers
    // above 2^53.
    var response struct {
        Result json.RawMessage `json:"result"`
        Error  string          `json:"error"`
    }
    if err := json.Unmarshal(responseData, &response); err != nil {
        return fmt.Errorf("error unmarshaling response: %w", err)
    }

    if response.Error != "" {
        return errors.New(response.Error)
    }

    // An omitted result means the handler returned nil; leave result as is.
    if result != nil && len(response.Result) > 0 {
        if err := json.Unmarshal(response.Result, result); err != nil {
            return fmt.Errorf("error unmarshaling result: %w", err)
        }
    }

    return nil
}
  3. 定义请求和响应结构
package redisrpc

// encoding/json is required for json.RawMessage below; the file does not
// compile without this import.
import "encoding/json"

// Request represents an RPC request. The client JSON-encodes it and pushes it
// onto the request list "rpc:<service>:requests".
type Request struct {
    // ID names this call's response key and notification channel.
    ID     string          `json:"id"`
    // Method selects the handler registered on the server.
    Method string          `json:"method"`
    // Params is kept as raw JSON; each handler decodes it itself.
    Params json.RawMessage `json:"params"`
}

// Response represents an RPC response. The server stores it JSON-encoded
// under "rpc:<service>:responses:<id>".
type Response struct {
    // ID echoes Request.ID.
    ID     string      `json:"id"`
    // Result is the handler's return value; omitted when nil.
    Result interface{} `json:"result,omitempty"`
    // Error is non-empty when the call failed; the client returns it as an error.
    Error  string      `json:"error,omitempty"`
}

使用示例

服务端
package main

import (
    "encoding/json"
    "log"
    "redisrpc"

    "github.com/go-redis/redis/v8"
)

func main() {
    // Create Redis client
    rdb := redis.NewClient(&redis.Options{
        Addr:     "localhost:6379",
        Password: "", // no password set
        DB:       0,  // use default DB
    })

    // Create RPC server
    server := redisrpc.NewServer(rdb, "calculator")

    // Register methods
    server.RegisterHandler("add", func(params json.RawMessage) (interface{}, error) {
        var numbers []int
        if err := json.Unmarshal(params, &numbers); err != nil {
            return nil, err
        }
        if len(numbers) != 2 {
            return nil, errors.New("expected 2 numbers")
        }
        return numbers[0] + numbers[1], nil
    })

    server.RegisterHandler("multiply", func(params json.RawMessage) (interface{}, error) {
        var numbers []int
        if err := json.Unmarshal(params, &numbers); err != nil {
            return nil, err
        }
        if len(numbers) != 2 {
            return nil, errors.New("expected 2 numbers")
        }
        return numbers[0] * numbers[1], nil
    })

    // Start server
    if err := server.Start(); err != nil {
        log.Fatal(err)
    }
    defer server.Stop()

    log.Println("RPC server started. Press Ctrl+C to stop.")
    select {} // Block forever
}
客户端
package main

import (
    "log"
    "redisrpc"

    "github.com/go-redis/redis/v8"
)

func main() {
    // Create Redis client
    rdb := redis.NewClient(&redis.Options{
        Addr:     "localhost:6379",
        Password: "", // no password set
        DB:       0,  // use default DB
    })

    // Create RPC client
    client := redisrpc.NewClient(rdb, "calculator", 5*time.Second)

    // Call add method
    var sum int
    if err := client.Call("add", []int{5, 3}, &sum); err != nil {
        log.Fatal(err)
    }
    log.Printf("5 + 3 = %d", sum)

    // Call multiply method
    var product int
    if err := client.Call("multiply", []int{5, 3}, &product); err != nil {
        log.Fatal(err)
    }
    log.Printf("5 * 3 = %d", product)
}
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容