设计思路
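通信机制:
- 使用Redis的List作为请求队列(客户端LPUSH,服务端BRPOP)
- 使用带过期时间(TTL)的String键存放每个请求的响应
- 使用Pub/Sub通知客户端响应已就绪
工作流程:
- 客户端先订阅该请求专属的通知频道,再将序列化后的请求放入请求队列
- 服务端监听请求队列,处理请求后把响应写入以请求ID命名的响应键(10分钟后过期),并向对应的通知频道发布就绪消息
- 客户端收到通知后读取响应键获取结果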
实现代码
- RPC服务端
package redisrpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"time"
"github.com/go-redis/redis/v8"
)
// Server consumes RPC requests for a single service from Redis and
// dispatches each one to the handler registered for its method.
type Server struct {
	// redisClient is the connection used for BRPOP, SET and PUBLISH.
	redisClient *redis.Client
	// service namespaces every Redis key and channel ("rpc:<service>:...").
	service string
	// handlers maps a method name to the function that serves it.
	handlers map[string]func(params json.RawMessage) (interface{}, error)
	// stopChan is closed by Stop to end the request loop.
	stopChan chan struct{}
}
// NewServer returns a Server for the named service that talks to Redis
// through redisClient. It starts with no handlers, and nothing runs until
// Start is called.
func NewServer(redisClient *redis.Client, service string) *Server {
	srv := new(Server)
	srv.redisClient = redisClient
	srv.service = service
	srv.handlers = make(map[string]func(params json.RawMessage) (interface{}, error))
	srv.stopChan = make(chan struct{})
	return srv
}
// RegisterHandler binds handler to the RPC method name. It replaces any
// handler previously bound to the same name.
//
// The handler map is not protected by a lock and is read by the request
// loop, so register every handler before calling Start.
func (s *Server) RegisterHandler(method string, handler func(params json.RawMessage) (interface{}, error)) {
	table := s.handlers
	table[method] = handler
}
// Start launches the request loop in a background goroutine and returns
// immediately. The returned error is always nil.
//
// The loop BRPOPs serialized Requests from the list "rpc:<service>:requests".
// Clients LPUSH onto it, so requests are served oldest first. Each request is
// dispatched to its registered handler on a new goroutine. The loop exits
// once Stop has been called. It checks for that between BRPOPs, which time
// out after one second.
func (s *Server) Start() error {
	requestQueue := fmt.Sprintf("rpc:%s:requests", s.service)
	go func() {
		for {
			select {
			case <-s.stopChan:
				return
			default:
			}
			// BRPop with timeout to allow checking stopChan
			popped, err := s.redisClient.BRPop(context.Background(), 1*time.Second, requestQueue).Result()
			if err != nil {
				if errors.Is(err, redis.Nil) {
					continue // timed out with an empty queue
				}
				log.Printf("Error reading from queue: %v", err)
				// Back off so that an unreachable Redis does not turn this
				// loop into a busy spin that floods the log. Stop is still
				// honoured while waiting.
				select {
				case <-s.stopChan:
					return
				case <-time.After(time.Second):
				}
				continue
			}
			// popped[0] is the key, popped[1] is the value
			if len(popped) < 2 {
				continue
			}
			s.dispatch(popped[1])
		}
	}()
	return nil
}

// dispatch decodes one raw request payload and runs its handler on a new
// goroutine. It replies with an error response if the method is unknown.
func (s *Server) dispatch(payload string) {
	var req Request
	if err := json.Unmarshal([]byte(payload), &req); err != nil {
		log.Printf("Error unmarshaling request: %v", err)
		return
	}
	handler, exists := s.handlers[req.Method]
	if !exists {
		s.sendErrorResponse(req.ID, fmt.Sprintf("Method %s not found", req.Method))
		return
	}
	go s.invoke(handler, req)
}

// invoke runs handler for req and sends back either its result or its
// error. A panicking handler is recovered and reported as an error response,
// so one faulty handler cannot crash the whole server process.
func (s *Server) invoke(handler func(params json.RawMessage) (interface{}, error), req Request) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Handler for method %s panicked: %v", req.Method, r)
			s.sendErrorResponse(req.ID, fmt.Sprintf("Method %s failed: internal error", req.Method))
		}
	}()
	result, err := handler(req.Params)
	if err != nil {
		s.sendErrorResponse(req.ID, err.Error())
		return
	}
	s.sendResponse(req.ID, result)
}
// Stop signals the request loop started by Start to exit. The loop notices
// the signal before its next BRPOP. Handler goroutines that are already
// running are not waited for.
//
// Calling Stop more than once in sequence is safe. Later calls see the
// channel already closed and do nothing, instead of panicking on a double
// close. Concurrent calls from multiple goroutines are not synchronized.
func (s *Server) Stop() {
	select {
	case <-s.stopChan:
		// already stopped
	default:
		close(s.stopChan)
	}
}
// sendResponse stores a successful result for request id and notifies the
// waiting client:
//   - it SETs the key "rpc:<service>:responses:<id>" with a 10-minute
//     expiry;
//   - it then PUBLISHes "ready" on "rpc:<service>:notifications:<id>".
//
// If result cannot be JSON-encoded, an error response is sent instead. This
// way the caller learns about the failure rather than getting no response at
// all. Redis failures are logged, not returned.
func (s *Server) sendResponse(id string, result interface{}) {
	data, err := json.Marshal(Response{ID: id, Result: result})
	if err != nil {
		log.Printf("Error marshaling response: %v", err)
		s.sendErrorResponse(id, fmt.Sprintf("error marshaling result: %v", err))
		return
	}
	responseKey := fmt.Sprintf("rpc:%s:responses:%s", s.service, id)
	if err := s.redisClient.Set(context.Background(), responseKey, data, 10*time.Minute).Err(); err != nil {
		log.Printf("Error sending response: %v", err)
		return
	}
	// Notify client that response is ready
	notificationChannel := fmt.Sprintf("rpc:%s:notifications:%s", s.service, id)
	if err := s.redisClient.Publish(context.Background(), notificationChannel, "ready").Err(); err != nil {
		log.Printf("Error publishing notification: %v", err)
	}
}
// sendErrorResponse stores a failed-call response carrying errorMsg for
// request id and notifies the waiting client:
//   - it SETs the key "rpc:<service>:responses:<id>" with a 10-minute
//     expiry;
//   - it then PUBLISHes "ready" on "rpc:<service>:notifications:<id>".
//
// The notification is essential. The client only reads the response key
// after being notified, so without it every error would surface to the
// caller as a timeout. Redis failures are logged, not returned.
func (s *Server) sendErrorResponse(id string, errorMsg string) {
	responseKey := fmt.Sprintf("rpc:%s:responses:%s", s.service, id)
	response := Response{
		ID:     id,
		Result: nil,
		Error:  errorMsg,
	}
	data, err := json.Marshal(response)
	if err != nil {
		log.Printf("Error marshaling error response: %v", err)
		return
	}
	if err := s.redisClient.Set(context.Background(), responseKey, data, 10*time.Minute).Err(); err != nil {
		log.Printf("Error sending error response: %v", err)
		return
	}
	// Notify client that the (error) response is ready
	notificationChannel := fmt.Sprintf("rpc:%s:notifications:%s", s.service, id)
	if err := s.redisClient.Publish(context.Background(), notificationChannel, "ready").Err(); err != nil {
		log.Printf("Error publishing notification: %v", err)
	}
}
- RPC客户端
package redisrpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
)
// Client issues RPC calls to a single service over Redis.
type Client struct {
	// redisClient is the connection used to enqueue requests, subscribe to
	// notifications and read responses.
	redisClient *redis.Client
	// service selects the "rpc:<service>:..." keys and channels.
	service string
	// timeout bounds how long Call waits for a response.
	timeout time.Duration
}
// NewClient returns a Client that calls methods of service through
// redisClient, using timeout as the per-call time limit.
func NewClient(redisClient *redis.Client, service string, timeout time.Duration) *Client {
	c := new(Client)
	c.redisClient, c.service, c.timeout = redisClient, service, timeout
	return c
}
// Call invokes method on the remote service with params. On success it
// JSON-decodes the remote result into result, which should be a pointer;
// pass nil to discard the result.
//
// Protocol:
//  1. Subscribe to "rpc:<service>:notifications:<id>".
//  2. LPUSH the request onto "rpc:<service>:requests".
//  3. Wait for a notification.
//  4. GET the response stored under "rpc:<service>:responses:<id>".
//
// The whole exchange, including subscribing and fetching the response, is
// bounded by the client's timeout. Running out of time yields
// "request timed out". A non-empty error in the response is returned as an
// error whose message is that text.
func (c *Client) Call(method string, params interface{}, result interface{}) error {
	requestID := uuid.New().String()
	requestQueue := fmt.Sprintf("rpc:%s:requests", c.service)
	responseKey := fmt.Sprintf("rpc:%s:responses:%s", c.service, requestID)
	notificationChannel := fmt.Sprintf("rpc:%s:notifications:%s", c.service, requestID)

	// Encode the request first so that a bad payload fails before touching Redis.
	paramsData, err := json.Marshal(params)
	if err != nil {
		return fmt.Errorf("error marshaling params: %w", err)
	}
	requestData, err := json.Marshal(Request{ID: requestID, Method: method, Params: paramsData})
	if err != nil {
		return fmt.Errorf("error marshaling request: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
	defer cancel()

	// Subscribe before sending the request, and wait for Redis to confirm
	// the subscription. Subscribe on its own does not wait, so the server's
	// "ready" message could be published before the subscription is active.
	// It would then be lost, turning a fast reply into a timeout.
	pubsub := c.redisClient.Subscribe(ctx, notificationChannel)
	defer pubsub.Close()
	if _, err := pubsub.Receive(ctx); err != nil {
		if ctx.Err() != nil {
			return errors.New("request timed out")
		}
		return fmt.Errorf("error subscribing to notifications: %w", err)
	}
	notifications := pubsub.Channel()

	// Send request
	if err := c.redisClient.LPush(ctx, requestQueue, requestData).Err(); err != nil {
		return fmt.Errorf("error sending request: %w", err)
	}

	// Wait for notification or timeout
	select {
	case _, ok := <-notifications:
		if !ok {
			return errors.New("notification channel closed")
		}
	case <-ctx.Done():
		return errors.New("request timed out")
	}

	responseData, err := c.redisClient.Get(ctx, responseKey).Result()
	if err != nil {
		return fmt.Errorf("error getting response: %w", err)
	}
	// Best-effort cleanup. The server stores the key with an expiry, so a
	// failed delete only delays its removal.
	_ = c.redisClient.Del(ctx, responseKey).Err()

	// Decode Result as raw JSON instead of going through an interface{}
	// field. Round-tripping via interface{} turns every number into float64
	// and silently corrupts integers larger than 2^53.
	var response struct {
		Result json.RawMessage `json:"result"`
		Error  string          `json:"error"`
	}
	if err := json.Unmarshal([]byte(responseData), &response); err != nil {
		return fmt.Errorf("error unmarshaling response: %w", err)
	}
	if response.Error != "" {
		return errors.New(response.Error)
	}
	// An absent result leaves *result untouched, as decoding JSON null would.
	if result != nil && len(response.Result) > 0 {
		if err := json.Unmarshal(response.Result, result); err != nil {
			return fmt.Errorf("error unmarshaling result: %w", err)
		}
	}
	return nil
}
- 定义请求和响应结构
package redisrpc
import "encoding/json"

// Request is the message a client LPUSHes onto "rpc:<service>:requests".
// Params carries the method arguments as raw JSON so that each handler can
// decode them into its own type.
type Request struct {
	// ID correlates the request with its response key and notification channel.
	ID string `json:"id"`
	// Method is the name a handler was registered under.
	Method string          `json:"method"`
	Params json.RawMessage `json:"params"`
}

// Response is the message the server stores under
// "rpc:<service>:responses:<id>". The server fills Result on success and
// Error (non-empty) on failure. Both are omitted from the JSON when empty.
type Response struct {
	ID     string      `json:"id"`
	Result interface{} `json:"result,omitempty"`
	Error  string      `json:"error,omitempty"`
}
使用示例
服务端
package main
import (
"encoding/json"
"log"
"redisrpc"
"github.com/go-redis/redis/v8"
)
func main() {
// Create Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
})
// Create RPC server
server := redisrpc.NewServer(rdb, "calculator")
// Register methods
server.RegisterHandler("add", func(params json.RawMessage) (interface{}, error) {
var numbers []int
if err := json.Unmarshal(params, &numbers); err != nil {
return nil, err
}
if len(numbers) != 2 {
return nil, errors.New("expected 2 numbers")
}
return numbers[0] + numbers[1], nil
})
server.RegisterHandler("multiply", func(params json.RawMessage) (interface{}, error) {
var numbers []int
if err := json.Unmarshal(params, &numbers); err != nil {
return nil, err
}
if len(numbers) != 2 {
return nil, errors.New("expected 2 numbers")
}
return numbers[0] * numbers[1], nil
})
// Start server
if err := server.Start(); err != nil {
log.Fatal(err)
}
defer server.Stop()
log.Println("RPC server started. Press Ctrl+C to stop.")
select {} // Block forever
}
客户端
package main
import (
"log"
"redisrpc"
"github.com/go-redis/redis/v8"
)
func main() {
// Create Redis client
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
})
// Create RPC client
client := redisrpc.NewClient(rdb, "calculator", 5*time.Second)
// Call add method
var sum int
if err := client.Call("add", []int{5, 3}, &sum); err != nil {
log.Fatal(err)
}
log.Printf("5 + 3 = %d", sum)
// Call multiply method
var product int
if err := client.Call("multiply", []int{5, 3}, &product); err != nil {
log.Fatal(err)
}
log.Printf("5 * 3 = %d", product)
}