diff --git a/build_protos.sh b/build_protos.sh new file mode 100644 index 000000000..0a21e13f4 --- /dev/null +++ b/build_protos.sh @@ -0,0 +1 @@ +protoc --go_out=. ./internal/wal/wal.proto \ No newline at end of file diff --git a/config/config.go b/config/config.go index 9809c135f..e866aeab9 100644 --- a/config/config.go +++ b/config/config.go @@ -55,6 +55,11 @@ var ( EnableProfiling = false EnableWatch = true + LogDir = "" + + EnableWAL = true + RestoreFromWAL = false + WALEngine = "sqlite" ) type Config struct { diff --git a/docs/src/content/docs/commands/GETDEL.md b/docs/src/content/docs/commands/GETDEL.md index 29930ad42..f08f87dee 100644 --- a/docs/src/content/docs/commands/GETDEL.md +++ b/docs/src/content/docs/commands/GETDEL.md @@ -88,7 +88,6 @@ Setting a key `mylist` as a list and then trying to use `GETDEL`, which is incom 127.0.0.1:7379> LPUSH mylist "item1" (integer) 1 127.0.0.1:7379> GETDEL mylist -<<<<<<< HEAD ERROR WRONGTYPE Operation against a key holding the wrong kind of value ``` @@ -97,7 +96,3 @@ ERROR WRONGTYPE Operation against a key holding the wrong kind of value - The key `mylist` is a list, not a string. - # The `GETDEL` command raises a `WRONGTYPE` error because it expects the key to be a string. 
(error) WRONGTYPE Operation against a key holding the wrong kind of value - -``` ->>>>>>> d43577926873d0df0c8f189cdde6afa65c515ccb -``` diff --git a/go.mod b/go.mod index 8d93328fb..9791cfecd 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 + github.com/mattn/go-sqlite3 v1.14.24 github.com/mmcloughlin/geohash v0.10.0 github.com/ohler55/ojg v1.25.0 github.com/pelletier/go-toml/v2 v2.2.3 @@ -58,4 +59,5 @@ require ( github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c + google.golang.org/protobuf v1.35.1 ) diff --git a/go.sum b/go.sum index aefcf136c..2b7c9f001 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mmcloughlin/geohash v0.10.0 h1:9w1HchfDfdeLc+jFEf/04D27KP7E2QmpDu52wPbJWRE= @@ -131,6 +133,8 @@ golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +google.golang.org/protobuf v1.35.1 
h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/integration_tests/commands/async/setup.go b/integration_tests/commands/async/setup.go index cfdcaebdc..b8a3ea684 100644 --- a/integration_tests/commands/async/setup.go +++ b/integration_tests/commands/async/setup.go @@ -123,7 +123,7 @@ func RunTestServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption gec := make(chan error) shardManager := shard.NewShardManager(1, watchChan, nil, gec) // Initialize the AsyncServer - testServer := server.NewAsyncServer(shardManager, watchChan) + testServer := server.NewAsyncServer(shardManager, watchChan, nil) // Try to bind to a port with a maximum of `totalRetries` retries. 
for i := 0; i < totalRetries; i++ { diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index a38d2b075..62bd81b5d 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -108,7 +108,7 @@ func RunHTTPServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerOption queryWatcherLocal := querymanager.NewQueryManager() config.HTTPPort = opt.Port // Initialize the HTTPServer - testServer := server.NewHTTPServer(shardManager) + testServer := server.NewHTTPServer(shardManager, nil) // Inform the user that the server is starting fmt.Println("Starting the test server on port", config.HTTPPort) shardManagerCtx, cancelShardManager := context.WithCancel(ctx) diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go index 7026134eb..a1293a8c3 100644 --- a/integration_tests/commands/resp/setup.go +++ b/integration_tests/commands/resp/setup.go @@ -128,7 +128,7 @@ func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { shardManager := shard.NewShardManager(1, queryWatchChan, cmdWatchChan, gec) workerManager := worker.NewWorkerManager(20000, shardManager) // Initialize the RESP Server - testServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec) + testServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, gec, nil) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.DiceConfig.AsyncServer.Port) diff --git a/integration_tests/commands/websocket/setup.go b/integration_tests/commands/websocket/setup.go index cfb748afb..b75b17c4f 100644 --- a/integration_tests/commands/websocket/setup.go +++ b/integration_tests/commands/websocket/setup.go @@ -109,7 +109,7 @@ func RunWebsocketServer(ctx context.Context, wg *sync.WaitGroup, opt TestServerO shardManager := shard.NewShardManager(1, watchChan, nil, 
globalErrChannel) queryWatcherLocal := querymanager.NewQueryManager() config.WebsocketPort = opt.Port - testServer := server.NewWebSocketServer(shardManager, testPort1) + testServer := server.NewWebSocketServer(shardManager, testPort1, nil) shardManagerCtx, cancelShardManager := context.WithCancel(ctx) // run shard manager diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index 68143983a..a561764be 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -17,9 +17,14 @@ type RedisCmds struct { RequestID uint32 } +// Repr returns a string representation of the command. +func (cmd *DiceDBCmd) Repr() string { + return fmt.Sprintf("%s %s", cmd.Cmd, strings.Join(cmd.Args, " ")) +} + // GetFingerprint returns a 32-bit fingerprint of the command and its arguments. func (cmd *DiceDBCmd) GetFingerprint() uint32 { - return farm.Fingerprint32([]byte(fmt.Sprintf("%s-%s", cmd.Cmd, strings.Join(cmd.Args, " ")))) + return farm.Fingerprint32([]byte(cmd.Repr())) } // GetKey Returns the key which the command operates on. 
diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go
index c748be162..2a993c9a0 100644
--- a/internal/server/httpServer.go
+++ b/internal/server/httpServer.go
@@ -13,8 +13,8 @@ import (
 	"time"
 
 	"github.com/dicedb/dice/internal/eval"
-
 	"github.com/dicedb/dice/internal/server/abstractserver"
+	"github.com/dicedb/dice/internal/wal"
 
 	"github.com/dicedb/dice/config"
 	"github.com/dicedb/dice/internal/clientio"
@@ -61,7 +61,7 @@ func (cim *CaseInsensitiveMux) ServeHTTP(w http.ResponseWriter, r *http.Request)
 	cim.mux.ServeHTTP(w, r)
 }
 
-func NewHTTPServer(shardManager *shard.ShardManager) *HTTPServer {
+func NewHTTPServer(shardManager *shard.ShardManager, wl wal.AbstractWAL) *HTTPServer {
 	mux := http.NewServeMux()
 	caseInsensitiveMux := &CaseInsensitiveMux{mux: mux}
 	srv := &http.Server{
diff --git a/internal/server/resp/server.go b/internal/server/resp/server.go
index 1fb8b86c0..b0834e2a8 100644
--- a/internal/server/resp/server.go
+++ b/internal/server/resp/server.go
@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/dicedb/dice/internal/server/abstractserver"
+	"github.com/dicedb/dice/internal/wal"
 
 	dstore "github.com/dicedb/dice/internal/store"
 	"github.com/dicedb/dice/internal/watchmanager"
@@ -49,20 +50,25 @@ type Server struct {
 	cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
 	cmdWatchChan             chan dstore.CmdWatchEvent
 	globalErrorChan          chan error
+	wl                       wal.AbstractWAL
 }
 
-func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager, cmdWatchSubscriptionChan chan watchmanager.WatchSubscription,
-	cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error) *Server {
+// NewServer builds a RESP server. cmdWatchSubscriptionChan is kept on the
+// server because AcceptConnectionRequests hands it to every worker it creates;
+// wl is the WAL passed on to those workers.
+func NewServer(shardManager *shard.ShardManager, workerManager *worker.WorkerManager,
+	cmdWatchSubscriptionChan chan watchmanager.WatchSubscription, cmdWatchChan chan dstore.CmdWatchEvent, globalErrChan chan error, wl wal.AbstractWAL) *Server {
 	return &Server{
 		Host:                     config.DiceConfig.AsyncServer.Addr,
 		Port:                     config.DiceConfig.AsyncServer.Port,
 		connBacklogSize:          DefaultConnBacklogSize,
 		workerManager:            workerManager,
 		shardManager:             shardManager,
 		watchManager:             watchmanager.NewManager(cmdWatchSubscriptionChan),
 		cmdWatchChan:             cmdWatchChan,
 		cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
 		globalErrorChan:          globalErrChan,
+		wl:                       wl,
 	}
 }
 
@@ -198,7 +204,7 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou
 			preprocessingChan := make(chan *ops.StoreResponse) // preprocessingChan is specifically for handling responses from shards for commands that require preprocessing
 
 			wID := GenerateUniqueWorkerID()
-			w := worker.NewWorker(wID, responseChan, preprocessingChan, s.cmdWatchSubscriptionChan, ioHandler, parser, s.shardManager, s.globalErrorChan)
+			w := worker.NewWorker(wID, responseChan, preprocessingChan, s.cmdWatchSubscriptionChan, ioHandler, parser, s.shardManager, s.globalErrorChan, s.wl)
 
 			// Register the worker with the worker manager
 			err = s.workerManager.RegisterWorker(w)
diff --git a/internal/server/server.go b/internal/server/server.go
index cf9b98ca1..3b27efa79 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -14,6 +14,7 @@ import (
 	"time"
 
 	"github.com/dicedb/dice/internal/server/abstractserver"
+	"github.com/dicedb/dice/internal/wal"
 
 	"github.com/dicedb/dice/config"
 	"github.com/dicedb/dice/internal/auth"
@@ -44,7 +45,7 @@ type AsyncServer struct {
 }
 
 // NewAsyncServer initializes a new AsyncServer
-func NewAsyncServer(shardManager *shard.ShardManager, queryWatchChan chan dstore.QueryWatchEvent) *AsyncServer {
+func NewAsyncServer(shardManager *shard.ShardManager, queryWatchChan chan dstore.QueryWatchEvent, wl wal.AbstractWAL) *AsyncServer {
 	return &AsyncServer{
 		maxClients:             config.DiceConfig.Performance.MaxClients,
 		connectedClients:       make(map[int]*comm.Client),
diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go
index 7d30f60dc..0f3f6d75d 100644
--- a/internal/server/websocketServer.go
+++ b/internal/server/websocketServer.go
@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"github.com/dicedb/dice/internal/server/abstractserver"
+	"github.com/dicedb/dice/internal/wal"
 
 	"github.com/dicedb/dice/config"
 	"github.com/dicedb/dice/internal/clientio"
@@ -46,7 +47,7 @@ type WebsocketServer struct {
 	shutdownChan       chan struct{}
 }
 
-func NewWebSocketServer(shardManager *shard.ShardManager, port int) *WebsocketServer {
+func NewWebSocketServer(shardManager *shard.ShardManager, port int, wl wal.AbstractWAL) *WebsocketServer {
 	mux := http.NewServeMux()
 	srv := &http.Server{
 		Addr: fmt.Sprintf(":%d", port),
diff --git a/internal/wal/wal.go b/internal/wal/wal.go
new file mode 100644
index 000000000..ca251560e
--- /dev/null
+++ b/internal/wal/wal.go
@@ -0,0 +1,90 @@
+package wal
+
+import (
+	"log/slog"
+	"sync"
+	"time"
+
+	"github.com/dicedb/dice/internal/cmd"
+)
+
+// rotationInterval is how often the background goroutine started by InitBG
+// rotates the WAL.
+const rotationInterval = 1 * time.Minute
+
+// AbstractWAL is the interface implemented by every WAL engine.
+type AbstractWAL interface {
+	// LogCommand records c in the WAL.
+	LogCommand(c *cmd.DiceDBCmd)
+	// Close releases the resources held by the current WAL.
+	Close() error
+	// Init prepares a WAL for the time t.
+	Init(t time.Time) error
+	// ForEachCommand calls f for every command stored in the WAL.
+	ForEachCommand(f func(c cmd.DiceDBCmd) error) error
+}
+
+var (
+	stopCh   = make(chan struct{}) // closed by ShutdownBG to stop rotation
+	stopOnce sync.Once             // makes ShutdownBG idempotent
+	mu       sync.Mutex            // serializes rotations
+)
+
+// rotateWAL closes the current WAL and initializes a new one stamped with the
+// current time. Errors are logged, not returned.
+func rotateWAL(wl AbstractWAL) {
+	mu.Lock()
+	defer mu.Unlock()
+
+	if err := wl.Close(); err != nil {
+		slog.Warn("error closing the WAL", slog.Any("error", err))
+	}
+
+	if err := wl.Init(time.Now()); err != nil {
+		slog.Warn("error creating a new WAL", slog.Any("error", err))
+	}
+}
+
+// periodicRotate rotates wl every rotationInterval until ShutdownBG is called.
+func periodicRotate(wl AbstractWAL) {
+	// The ticker is owned by this goroutine so that merely importing the
+	// package does not start a timer, and so that it is always stopped.
+	ticker := time.NewTicker(rotationInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			rotateWAL(wl)
+		case <-stopCh:
+			return
+		}
+	}
+}
+
+// InitBG starts the background goroutine that periodically rotates wl. It
+// returns immediately.
+func InitBG(wl AbstractWAL) {
+	go periodicRotate(wl)
+}
+
+// ShutdownBG stops every rotation goroutine started by InitBG. It is safe to
+// call more than once.
+func ShutdownBG() {
+	stopOnce.Do(func() {
+		close(stopCh)
+	})
+}
+
+// ReplayWAL walks every command stored in wl. Commands are only logged for
+// now; errors are logged, not returned.
+func ReplayWAL(wl AbstractWAL) {
+	err := wl.ForEachCommand(func(c cmd.DiceDBCmd) error {
+		slog.Info("replaying", slog.String("cmd", c.Cmd), slog.Any("args", c.Args))
+		return nil
+	})
+
+	if err != nil {
+		slog.Warn("error replaying WAL", slog.Any("error", err))
+	}
+}
diff --git a/internal/wal/wal.pb.go b/internal/wal/wal.pb.go
new file mode 100644
index 000000000..3adf5277b
--- /dev/null
+++ b/internal/wal/wal.pb.go
@@ -0,0 +1,137 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// 	protoc-gen-go v1.35.1
+// 	protoc        v3.12.4
+// source: internal/wal/wal.proto
+
+package wal
+
+import (
+	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+	reflect "reflect"
+	sync "sync"
+)
+
+const (
+	// Verify that this generated code is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+	// Verify that runtime/protoimpl is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+// WALLogEntry represents a single log entry in the WAL.
+type WALLogEntry struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Checksum []byte `protobuf:"bytes,1,opt,name=checksum,proto3" json:"checksum,omitempty"` // SHA-256 checksum of the command for integrity + Command string `protobuf:"bytes,2,opt,name=command,proto3" json:"command,omitempty"` // Command +} + +func (x *WALLogEntry) Reset() { + *x = WALLogEntry{} + mi := &file_internal_wal_wal_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WALLogEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WALLogEntry) ProtoMessage() {} + +func (x *WALLogEntry) ProtoReflect() protoreflect.Message { + mi := &file_internal_wal_wal_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WALLogEntry.ProtoReflect.Descriptor instead. 
+func (*WALLogEntry) Descriptor() ([]byte, []int) { + return file_internal_wal_wal_proto_rawDescGZIP(), []int{0} +} + +func (x *WALLogEntry) GetChecksum() []byte { + if x != nil { + return x.Checksum + } + return nil +} + +func (x *WALLogEntry) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +var File_internal_wal_wal_proto protoreflect.FileDescriptor + +var file_internal_wal_wal_proto_rawDesc = []byte{ + 0x0a, 0x16, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x77, 0x61, 0x6c, 0x2f, 0x77, + 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x77, 0x61, 0x6c, 0x22, 0x43, 0x0a, + 0x0b, 0x57, 0x41, 0x4c, 0x4c, 0x6f, 0x67, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x1a, 0x0a, 0x08, + 0x63, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x75, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, + 0x63, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x75, 0x6d, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6d, 0x6d, + 0x61, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6d, 0x6d, 0x61, + 0x6e, 0x64, 0x42, 0x0e, 0x5a, 0x0c, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x77, + 0x61, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_internal_wal_wal_proto_rawDescOnce sync.Once + file_internal_wal_wal_proto_rawDescData = file_internal_wal_wal_proto_rawDesc +) + +func file_internal_wal_wal_proto_rawDescGZIP() []byte { + file_internal_wal_wal_proto_rawDescOnce.Do(func() { + file_internal_wal_wal_proto_rawDescData = protoimpl.X.CompressGZIP(file_internal_wal_wal_proto_rawDescData) + }) + return file_internal_wal_wal_proto_rawDescData +} + +var file_internal_wal_wal_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_internal_wal_wal_proto_goTypes = []any{ + (*WALLogEntry)(nil), // 0: wal.WALLogEntry +} +var file_internal_wal_wal_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name 
+ 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_internal_wal_wal_proto_init() } +func file_internal_wal_wal_proto_init() { + if File_internal_wal_wal_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_internal_wal_wal_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_internal_wal_wal_proto_goTypes, + DependencyIndexes: file_internal_wal_wal_proto_depIdxs, + MessageInfos: file_internal_wal_wal_proto_msgTypes, + }.Build() + File_internal_wal_wal_proto = out.File + file_internal_wal_wal_proto_rawDesc = nil + file_internal_wal_wal_proto_goTypes = nil + file_internal_wal_wal_proto_depIdxs = nil +} diff --git a/internal/wal/wal.proto b/internal/wal/wal.proto new file mode 100644 index 000000000..e935d665d --- /dev/null +++ b/internal/wal/wal.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package wal; +option go_package = "internal/wal"; + +// WALLogEntry represents a single log entry in the WAL. 
+message WALLogEntry {
+  bytes checksum = 1; // SHA-256 checksum of the command for integrity
+  string command = 2; // Command
+}
diff --git a/internal/wal/wal_aof.go b/internal/wal/wal_aof.go
new file mode 100644
index 000000000..ab690a097
--- /dev/null
+++ b/internal/wal/wal_aof.go
@@ -0,0 +1,232 @@
+package wal
+
+import (
+	"bufio"
+	"bytes"
+	"crypto/sha256"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/dicedb/dice/internal/cmd"
+	"google.golang.org/protobuf/proto"
+)
+
+// aofLenPrefixSize is the size in bytes of the big-endian length header that
+// precedes every serialized WALLogEntry in an .aof file.
+const aofLenPrefixSize = 4
+
+// WALAOF is a WAL that appends length-prefixed, protobuf-encoded WALLogEntry
+// records to a file under logDir.
+type WALAOF struct {
+	mutex    sync.Mutex   // guards file and writeBuf
+	file     *os.File     // current log file; nil before Init and after Close
+	writeBuf bytes.Buffer // scratch buffer reused across LogCommand calls
+	logDir   string
+}
+
+// NewAOFWAL returns an AOF-backed WAL that keeps its files in logDir. Init must
+// be called before commands can be logged.
+func NewAOFWAL(logDir string) (*WALAOF, error) {
+	return &WALAOF{
+		logDir: logDir,
+	}, nil
+}
+
+// Init creates logDir if needed and opens wal_<t as YYYYMMDD_HHMM>.aof in append
+// mode as the current log file. A previously opened file, if any, is closed
+// only after the new one is in place.
+func (w *WALAOF) Init(t time.Time) error {
+	slog.Debug("initializing WAL at", slog.Any("log-dir", w.logDir))
+	if err := os.MkdirAll(w.logDir, os.ModePerm); err != nil {
+		return fmt.Errorf("failed to create log directory: %w", err)
+	}
+
+	timestamp := t.Format("20060102_1504")
+	path := filepath.Join(w.logDir, fmt.Sprintf("wal_%s.aof", timestamp))
+
+	newFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+	if err != nil {
+		return fmt.Errorf("failed to open new WAL file: %w", err)
+	}
+
+	// Swap under the lock so a concurrent LogCommand never observes a file
+	// that is being replaced.
+	w.mutex.Lock()
+	oldFile := w.file
+	w.file = newFile
+	w.mutex.Unlock()
+
+	if oldFile != nil {
+		if err := oldFile.Close(); err != nil {
+			slog.Warn("error closing the previous WAL file", slog.Any("error", err))
+		}
+	}
+	return nil
+}
+
+// LogCommand serializes the command as a WALLogEntry and appends it, prefixed
+// with its length, to the current WAL file. Failures are logged, not returned.
+func (w *WALAOF) LogCommand(c *cmd.DiceDBCmd) {
+	repr := c.Repr()
+
+	data, err := proto.Marshal(&WALLogEntry{
+		Command:  repr,
+		Checksum: checksum(repr),
+	})
+	if err != nil {
+		slog.Warn("failed to serialize command", slog.Any("error", err.Error()))
+		return
+	}
+
+	w.mutex.Lock()
+	defer w.mutex.Unlock()
+
+	if w.file == nil {
+		slog.Warn("failed to write serialized command to WAL", slog.Any("error", os.ErrClosed.Error()))
+		return
+	}
+
+	var lenPrefix [aofLenPrefixSize]byte
+	binary.BigEndian.PutUint32(lenPrefix[:], uint32(len(data)))
+
+	// Assemble header and payload first so the record reaches the file in a
+	// single write.
+	w.writeBuf.Reset()
+	w.writeBuf.Grow(aofLenPrefixSize + len(data))
+	w.writeBuf.Write(lenPrefix[:])
+	w.writeBuf.Write(data)
+
+	if _, err := w.file.Write(w.writeBuf.Bytes()); err != nil {
+		slog.Warn("failed to write serialized command to WAL", slog.Any("error", err.Error()))
+		return
+	}
+
+	if err := w.file.Sync(); err != nil {
+		slog.Warn("failed to sync WAL", slog.Any("error", err.Error()))
+		return
+	}
+
+	slog.Debug("logged command in WAL", slog.Any("command", repr))
+}
+
+// Close closes the current WAL file, if one is open. It is safe to call more
+// than once.
+func (w *WALAOF) Close() error {
+	w.mutex.Lock()
+	defer w.mutex.Unlock()
+
+	if w.file == nil {
+		return nil
+	}
+	err := w.file.Close()
+	w.file = nil
+	return err
+}
+
+// checksum generates a SHA-256 hash for the given command.
+func checksum(command string) []byte {
+	hash := sha256.Sum256([]byte(command))
+	return hash[:]
+}
+
+// ForEachCommand replays every command stored in the .aof files of logDir,
+// oldest file first, calling f for each one. It stops at the first error.
+func (w *WALAOF) ForEachCommand(f func(c cmd.DiceDBCmd) error) error {
+	entries, err := os.ReadDir(w.logDir)
+	if err != nil {
+		return fmt.Errorf("failed to read log directory: %w", err)
+	}
+
+	var walFiles []string
+	for _, entry := range entries {
+		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".aof" {
+			walFiles = append(walFiles, entry.Name())
+		}
+	}
+
+	if len(walFiles) == 0 {
+		return fmt.Errorf("no valid WAL files found in log directory")
+	}
+
+	// File names embed a fixed-width YYYYMMDD_HHMM timestamp, so sorting by
+	// name is sorting by time, and it cannot panic on an unexpected name.
+	sort.Strings(walFiles)
+
+	for _, name := range walFiles {
+		if err := replayAOFFile(filepath.Join(w.logDir, name), f); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// replayAOFFile decodes the WAL file at path record by record and calls f for
+// each command. The file is closed on every return path.
+func replayAOFFile(path string, f func(c cmd.DiceDBCmd) error) error {
+	slog.Debug("loading WAL", slog.Any("file", path))
+
+	file, err := os.Open(path)
+	if err != nil {
+		return fmt.Errorf("failed to open WAL file %s: %w", path, err)
+	}
+	defer file.Close()
+
+	var (
+		reader    = bufio.NewReader(file)
+		lenPrefix [aofLenPrefixSize]byte
+		payload   []byte // reused across records
+	)
+
+	for {
+		if _, err := io.ReadFull(reader, lenPrefix[:]); err != nil {
+			if errors.Is(err, io.EOF) {
+				return nil // clean end of file
+			}
+			return fmt.Errorf("failed to read entry length: %w", err)
+		}
+
+		length := binary.BigEndian.Uint32(lenPrefix[:])
+		if uint64(cap(payload)) < uint64(length) {
+			payload = make([]byte, length)
+		}
+		payload = payload[:length]
+		if _, err := io.ReadFull(reader, payload); err != nil {
+			return fmt.Errorf("failed to read entry data: %w", err)
+		}
+
+		entry := &WALLogEntry{}
+		if err := proto.Unmarshal(payload, entry); err != nil {
+			return fmt.Errorf("failed to unmarshal WAL entry: %w", err)
+		}
+
+		if !bytes.Equal(entry.Checksum, checksum(entry.Command)) {
+			return fmt.Errorf("checksum mismatch in WAL entry: %s", entry.Command)
+		}
+
+		// Entries are written as "<cmd> <args joined by spaces>"; a command
+		// without arguments has nothing after the separator.
+		name, rest, _ := strings.Cut(entry.Command, " ")
+		if name == "" {
+			return fmt.Errorf("invalid command format in WAL entry: %s", entry.Command)
+		}
+
+		c := cmd.DiceDBCmd{Cmd: name}
+		if rest != "" {
+			c.Args = strings.Split(rest, " ")
+		}
+
+		if err := f(c); err != nil {
+			return fmt.Errorf("error processing command: %w", err)
+		}
+	}
+}
diff --git a/internal/wal/wal_null.go b/internal/wal/wal_null.go
new file mode 100644
index 000000000..c4a58ccc9
--- /dev/null
+++ b/internal/wal/wal_null.go
@@ -0,0 +1,30 @@
+package wal
+
+import (
+	"time"
+
+	"github.com/dicedb/dice/internal/cmd"
+)
+
+type WALNull struct {
+}
+
+func NewNullWAL() (*WALNull, error) {
+	return &WALNull{}, nil
+}
+
+func (w *WALNull) Init(t time.Time) error {
+	return nil
+}
+
+// LogCommand is a no-op: the null WAL discards every command.
+func (w *WALNull) LogCommand(c *cmd.DiceDBCmd) {
+}
+
+func (w *WALNull) Close() error {
+	return nil
+}
+
+func (w *WALNull) ForEachCommand(f func(c cmd.DiceDBCmd) error) error {
+	return nil
+}
diff --git a/internal/wal/wal_sqlite.go b/internal/wal/wal_sqlite.go
new file mode 100644
index 000000000..5805667d9
--- /dev/null
+++ b/internal/wal/wal_sqlite.go
@@ -0,0 +1,202 @@
+package wal
+
+import (
+	"database/sql"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/dicedb/dice/internal/cmd"
+	_ "github.com/mattn/go-sqlite3" // registers the "sqlite3" database/sql driver
+)
+
+// errSQLiteWALNotOpen is reported when a command is logged while no database
+// is open, i.e. before Init has succeeded or after Close.
+var errSQLiteWALNotOpen = errors.New("WAL database is not open")
+
+// WALSQLite is a WAL that stores each command as a row in a SQLite database
+// file under logDir.
+type WALSQLite struct {
+	logDir string
+	curDB  *sql.DB    // current database; nil before Init and after Close
+	mu     sync.Mutex // guards curDB
+}
+
+// NewSQLiteWAL returns a SQLite-backed WAL that keeps its files in logDir. Init
+// must be called before commands can be logged.
+func NewSQLiteWAL(logDir string) (*WALSQLite, error) {
+	return &WALSQLite{
+		logDir: logDir,
+	}, nil
+}
+
+// Init creates logDir if needed, opens wal_<t as YYYYMMDD_HHMM>.sqlite3 and
+// makes sure the wal table exists. A previously opened database, if any, is
+// closed only after the new one is in place.
+func (w *WALSQLite) Init(t time.Time) error {
+	slog.Debug("initializing WAL at", slog.Any("log-dir", w.logDir))
+	if err := os.MkdirAll(w.logDir, os.ModePerm); err != nil {
+		return fmt.Errorf("failed to create log directory: %w", err)
+	}
+
+	timestamp := t.Format("20060102_1504")
+	path := filepath.Join(w.logDir, fmt.Sprintf("wal_%s.sqlite3", timestamp))
+
+	db, err := sql.Open("sqlite3", path)
+	if err != nil {
+		return fmt.Errorf("failed to open WAL database: %w", err)
+	}
+
+	if err := prepareSQLiteWAL(db); err != nil {
+		// Do not leak the handle of a database that could not be set up.
+		_ = db.Close()
+		return err
+	}
+
+	// Swap under the lock so a concurrent LogCommand never observes a
+	// database that is being replaced.
+	w.mu.Lock()
+	oldDB := w.curDB
+	w.curDB = db
+	w.mu.Unlock()
+
+	if oldDB != nil {
+		if err := oldDB.Close(); err != nil {
+			slog.Warn("error closing the previous WAL database", slog.Any("error", err))
+		}
+	}
+	return nil
+}
+
+// prepareSQLiteWAL switches db to SQLite's WAL journal mode and creates the
+// wal table when it does not exist yet.
+func prepareSQLiteWAL(db *sql.DB) error {
+	if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
+		return fmt.Errorf("failed to set journal mode: %w", err)
+	}
+
+	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS wal (
+		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		command TEXT NOT NULL
+	);`); err != nil {
+		return fmt.Errorf("failed to create wal table: %w", err)
+	}
+	return nil
+}
+
+// LogCommand inserts the textual form of the command into the current
+// database. Failures are logged, not returned.
+func (w *WALSQLite) LogCommand(c *cmd.DiceDBCmd) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	if w.curDB == nil {
+		slog.Error("failed to log command in WAL", slog.Any("error", errSQLiteWALNotOpen))
+		return
+	}
+
+	repr := c.Repr()
+	if _, err := w.curDB.Exec("INSERT INTO wal (command) VALUES (?)", repr); err != nil {
+		slog.Error("failed to log command in WAL", slog.Any("error", err))
+		return
+	}
+	slog.Debug("logged command in WAL", slog.Any("command", repr))
+}
+
+// Close closes the current database, if one is open. It is safe to call more
+// than once.
+func (w *WALSQLite) Close() error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	if w.curDB == nil {
+		return nil
+	}
+	err := w.curDB.Close()
+	w.curDB = nil
+	return err
+}
+
+// ForEachCommand replays every command stored in the .sqlite3 files of logDir,
+// oldest file first, calling f for each one. It stops at the first error.
+func (w *WALSQLite) ForEachCommand(f func(c cmd.DiceDBCmd) error) error {
+	entries, err := os.ReadDir(w.logDir)
+	if err != nil {
+		return fmt.Errorf("failed to read log directory: %w", err)
+	}
+
+	var walFiles []string
+	for _, entry := range entries {
+		if !entry.IsDir() && filepath.Ext(entry.Name()) == ".sqlite3" {
+			walFiles = append(walFiles, entry.Name())
+		}
+	}
+
+	if len(walFiles) == 0 {
+		return fmt.Errorf("no valid WAL files found in log directory")
+	}
+
+	// File names embed a fixed-width YYYYMMDD_HHMM timestamp, so sorting by
+	// name is sorting by time, and it cannot panic on an unexpected name.
+	sort.Strings(walFiles)
+
+	for _, name := range walFiles {
+		if err := replaySQLiteFile(filepath.Join(w.logDir, name), f); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// replaySQLiteFile reads the commands stored in the database at path in
+// insertion order and calls f for each one. The rows and the database are
+// released on every return path.
+func replaySQLiteFile(path string, f func(c cmd.DiceDBCmd) error) (err error) {
+	slog.Debug("loading WAL", slog.Any("file", path))
+
+	db, err := sql.Open("sqlite3", path)
+	if err != nil {
+		return fmt.Errorf("failed to open WAL file %s: %w", path, err)
+	}
+	defer func() {
+		if cerr := db.Close(); cerr != nil && err == nil {
+			err = fmt.Errorf("failed to close WAL file %s: %w", path, cerr)
+		}
+	}()
+
+	rows, err := db.Query("SELECT command FROM wal ORDER BY id")
+	if err != nil {
+		return fmt.Errorf("failed to query WAL file %s: %w", path, err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var command string
+		if err := rows.Scan(&command); err != nil {
+			return fmt.Errorf("failed to scan WAL file %s: %w", path, err)
+		}
+
+		// Commands are stored as "<cmd> <args joined by spaces>"; a command
+		// without arguments has nothing after the separator.
+		name, rest, _ := strings.Cut(command, " ")
+		c := cmd.DiceDBCmd{Cmd: name}
+		if rest != "" {
+			c.Args = strings.Split(rest, " ")
+		}
+		if err := f(c); err != nil {
+			return err
+		}
+	}
+
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("failed to iterate WAL file %s: %w", path, err)
+	}
+	return nil
+}
diff --git a/internal/wal/wal_test.go b/internal/wal/wal_test.go
new file mode 100644
index 000000000..b3ee707e3
--- /dev/null
+++ b/internal/wal/wal_test.go
@@ -0,0 +1,50 @@
+package wal_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/dicedb/dice/internal/cmd"
+	"github.com/dicedb/dice/internal/wal"
+)
+
+// benchmarkLogCommand measures LogCommand on wl after initializing it. The WAL
+// is closed when the benchmark finishes.
+func benchmarkLogCommand(b *testing.B, wl wal.AbstractWAL) {
+	b.Helper()
+
+	if err := wl.Init(time.Now()); err != nil {
+		b.Fatalf("could not initialize WAL: %v", err)
+	}
+	b.Cleanup(func() {
+		if err := wl.Close(); err != nil {
+			b.Errorf("could not close WAL: %v", err)
+		}
+	})
+
+	c := &cmd.DiceDBCmd{
+		Cmd:  "SET",
+		Args: []string{"key", "value"},
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		wl.LogCommand(c)
+	}
+}
+
+func BenchmarkLogCommandSQLite(b *testing.B) {
+	wl, err := wal.NewSQLiteWAL(b.TempDir())
+	if err != nil {
+		b.Fatalf("could not create WAL: %v", err)
+	}
+	benchmarkLogCommand(b, wl)
+}
+
+func BenchmarkLogCommandAOF(b *testing.B) {
+	wl, err := wal.NewAOFWAL(b.TempDir())
+	if err != nil {
+		b.Fatalf("could not create WAL: %v", err)
+	}
+	benchmarkLogCommand(b, wl)
+}
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index 304d3aced..6e6027e25 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -14,6 +14,7 @@ import (
 	"time"
 
 	"github.com/dicedb/dice/internal/querymanager"
+	"github.com/dicedb/dice/internal/wal"
 	"github.com/dicedb/dice/internal/watchmanager"
 
 	"github.com/dicedb/dice/config"
@@ -49,23 +50,31 @@ type BaseWorker struct {
 	responseChan             chan *ops.StoreResponse
 	preprocessingChan        chan *ops.StoreResponse
 	cmdWatchSubscriptionChan chan watchmanager.WatchSubscription
+	wl                       wal.AbstractWAL
 }
 
 func NewWorker(wid string, responseChan, preprocessingChan chan *ops.StoreResponse,
 	cmdWatchSubscriptionChan chan watchmanager.WatchSubscription,
 	ioHandler iohandler.IOHandler, parser requestparser.Parser,
-	shardManager *shard.ShardManager, gec chan error) *BaseWorker {
+	shardManager *shard.ShardManager, gec chan error, wl wal.AbstractWAL) *BaseWorker {
+	// A nil WAL (as passed by the integration-test servers) would panic on the
+	// first LogCommand call, so fall back to the no-op implementation.
+	if wl == nil {
+		wl = &wal.WALNull{}
+	}
+
 	return &BaseWorker{
 		id:                       wid,
 		ioHandler:                ioHandler,
 		parser:                   parser,
 		shardManager:             shardManager,
 		globalErrorChan:          gec,
 		responseChan:             responseChan,
 		preprocessingChan:        preprocessingChan,
 		cmdWatchSubscriptionChan: cmdWatchSubscriptionChan,
 		Session:                  auth.NewSession(),
 		adhocReqChan:             make(chan *cmd.DiceDBCmd, config.DiceConfig.Performance.AdhocReqChanBufSize),
+		wl:                       wl,
 	}
 }
 
@@ -423,6 +432,8 @@ func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCm
 			return err
 		}
 
+		w.wl.LogCommand(diceDBCmd)
+
 	case MultiShard:
 		err := w.ioHandler.Write(ctx, val.composeResponse(storeOp...))
 		if err != nil {
@@ -430,6 +441,8 @@ func (w *BaseWorker) gather(ctx context.Context, diceDBCmd *cmd.DiceDBCmd, numCm
 			return err
 		}
 
+		w.wl.LogCommand(diceDBCmd)
+
 	default:
 		slog.Error("Unknown command type", slog.String("workerID", w.id), slog.String("command", diceDBCmd.Cmd), slog.Any("evalResp", storeOp))
 		err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer)
diff --git a/main.go b/main.go
index 20fbc109b..6913fa75d 100644
--- a/main.go
+++ b/main.go
@@ -15,9 +15,11 @@ import (
 	"strings"
 	"sync"
 	"syscall"
+	"time"
 
 	"github.com/dicedb/dice/internal/logger"
 	"github.com/dicedb/dice/internal/server/abstractserver"
+	"github.com/dicedb/dice/internal/wal"
 	"github.com/dicedb/dice/internal/watchmanager"
 
 	"github.com/dicedb/dice/config"
@@ -55,6 +57,11 @@ func init() {
 	flag.BoolVar(&config.EnableProfiling, "enable-profiling", false, "enable profiling and capture critical metrics and traces in .prof files")
 
 	flag.StringVar(&config.DiceConfig.Logging.LogLevel, "log-level", "info", "log level, values: info, debug")
+	flag.StringVar(&config.LogDir, "log-dir", "/tmp/dicedb", "log directory path")
+
+	flag.BoolVar(&config.EnableWAL, "enable-wal", false, "enable write-ahead logging")
+	flag.BoolVar(&config.RestoreFromWAL, "restore-wal", false, "restore the database from the WAL files")
+	flag.StringVar(&config.WALEngine, "wal-engine", "null", "wal engine to use, values: null, sqlite, aof")
 
 	flag.StringVar(&config.RequirePass, "requirepass", config.RequirePass, "enable authentication for the default user")
 	flag.StringVar(&config.CustomConfigFilePath, "o", config.CustomConfigFilePath, "dir path to create the config file")
@@ -174,10 +181,54 @@ func main() {
 	var (
 		queryWatchChan           chan dstore.QueryWatchEvent
 		cmdWatchChan             chan dstore.CmdWatchEvent
 		cmdWatchSubscriptionChan = make(chan watchmanager.WatchSubscription)
 		serverErrCh              = make(chan error, 2)
+		wl                       wal.AbstractWAL
 	)
 
+	// Start from the no-op WAL so that every server always receives a usable,
+	// non-nil WAL even when write-ahead logging is disabled.
+	wl, _ = wal.NewNullWAL() // NewNullWAL never returns an error
+	slog.Info("running with", slog.Bool("enable-wal", config.EnableWAL))
+	if config.EnableWAL {
+		var err error
+		switch config.WALEngine {
+		case "null":
+			// Default engine: keep the no-op WAL created above.
+		case "sqlite":
+			wl, err = wal.NewSQLiteWAL(config.LogDir)
+		case "aof":
+			wl, err = wal.NewAOFWAL(config.LogDir)
+		default:
+			slog.Error("unsupported WAL engine", slog.String("engine", config.WALEngine))
+			sigs <- syscall.SIGKILL
+			return
+		}
+		if err != nil {
+			slog.Warn("could not create WAL with", slog.String("wal-engine", config.WALEngine), slog.Any("error", err))
+			sigs <- syscall.SIGKILL
+			return
+		}
+
+		// An engine that failed to initialize has no open log to append to,
+		// so treat it like a failed construction instead of serving with it.
+		if err := wl.Init(time.Now()); err != nil {
+			slog.Error("could not initialize WAL", slog.Any("error", err))
+			sigs <- syscall.SIGKILL
+			return
+		}
+		// InitBG starts its own goroutine; no extra `go` is needed here.
+		wal.InitBG(wl)
+
+		slog.Debug("WAL initialization complete")
+
+		if config.RestoreFromWAL {
+			slog.Info("restoring database from WAL")
+			wal.ReplayWAL(wl)
+			slog.Info("database restored from WAL")
+		}
+	}
+
 	if config.EnableWatch {
 		bufSize := config.DiceConfig.Performance.WatchChanBufSize
 		queryWatchChan = make(chan dstore.QueryWatchEvent, bufSize)
@@ -229,11 +280,11 @@ func main() {
 		}
 
 		workerManager := worker.NewWorkerManager(config.DiceConfig.Performance.MaxClients, shardManager)
-		respServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh)
+		respServer := resp.NewServer(shardManager, workerManager, cmdWatchSubscriptionChan, cmdWatchChan, serverErrCh, wl)
 		serverWg.Add(1)
 		go runServer(ctx, &serverWg, respServer, serverErrCh)
 	} else {
-		asyncServer := server.NewAsyncServer(shardManager, queryWatchChan)
+		asyncServer := server.NewAsyncServer(shardManager, queryWatchChan, wl)
 		if err := asyncServer.FindPortAndBind(); err != nil {
 			slog.Error("Error finding and binding port", slog.Any("error", err))
 			sigs <- syscall.SIGKILL
@@ -243,14 +294,14 @@ func main() {
 		go runServer(ctx, &serverWg, asyncServer, serverErrCh)
 
 		if config.EnableHTTP {
-			httpServer := server.NewHTTPServer(shardManager)
+			httpServer := server.NewHTTPServer(shardManager, wl)
 			serverWg.Add(1)
 			go runServer(ctx, &serverWg, httpServer, serverErrCh)
 		}
 	}
 
 	if config.EnableWebsocket {
-		websocketServer := server.NewWebSocketServer(shardManager, config.WebsocketPort)
+		websocketServer := server.NewWebSocketServer(shardManager, config.WebsocketPort, wl)
 		serverWg.Add(1)
 		go runServer(ctx, &serverWg, websocketServer, serverErrCh)
 	}
@@ -276,6 +327,11 @@ func main() {
 	}
 
 	close(sigs)
+
+	if config.EnableWAL {
+		wal.ShutdownBG()
+	}
+
 	cancel()
 
 	wg.Wait()