From acfe7e789c8788fe875120d6b5e5dc4166a57c62 Mon Sep 17 00:00:00 2001 From: Kyle Chui Date: Mon, 24 Jun 2024 09:11:06 -0700 Subject: [PATCH] feat(config): Automatically wrap `add` key in a table. Inspired by the question in #341. Makes the UX more uniform between the "simple" version of the `add` key and more advanced callback functions. --- lua/nvim-surround/annotations.lua | 2 +- lua/nvim-surround/config.lua | 24 ++++++++++++++---------- tests/configuration_spec.lua | 20 ++++++++++++++++++++ 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/lua/nvim-surround/annotations.lua b/lua/nvim-surround/annotations.lua index f8a0de0..6149b92 100644 --- a/lua/nvim-surround/annotations.lua +++ b/lua/nvim-surround/annotations.lua @@ -6,7 +6,7 @@ ---@alias position integer[] A 1-indexed position in the buffer ---@alias delimiter string[] The text representation of a delimiter ---@alias delimiter_pair delimiter[] A pair of delimiters ----@alias add_func fun(char: string|nil): delimiter_pair|nil +---@alias add_func fun(char: string|nil): delimiter_pair|string[]|nil ---@alias find_func fun(char: string|nil): selection|nil ---@alias delete_func fun(char: string|nil): selections|nil ---@alias change_table { target: delete_func, replacement: add_func|nil } diff --git a/lua/nvim-surround/config.lua b/lua/nvim-surround/config.lua index a141c7d..487409c 100644 --- a/lua/nvim-surround/config.lua +++ b/lua/nvim-surround/config.lua @@ -360,11 +360,18 @@ M.get_delimiters = function(char, line_mode) char = M.get_alias(char) -- Get the delimiters, using invalid_key_behavior if the add function is undefined for the character local delimiters = M.get_add(char)(char) - -- Add new lines if the addition is done line-wise - if delimiters and line_mode then - local lhs = delimiters[1] - local rhs = delimiters[2] + if delimiters == nil then + return nil + end + local lhs = type(delimiters[1]) == "string" and { delimiters[1] } or delimiters[1] + local rhs = type(delimiters[2]) == "string" and { delimiters[2] } or delimiters[2] + -- These casts are needed because LuaLS doesn't narrow types in ternaries properly + -- https://github.com/LuaLS/lua-language-server/issues/2233 + ---@cast lhs string[] + ---@cast rhs string[] + -- Add new lines if the addition is done line-wise + if line_mode then -- Trim whitespace after the leading delimiter and before the trailing delimiter lhs[#lhs] = lhs[#lhs]:gsub("%s+$", "") rhs[1] = rhs[1]:gsub("^%s+", "") @@ -373,7 +380,7 @@ M.get_delimiters = function(char, line_mode) table.insert(lhs, "") end - return delimiters + return { lhs, rhs } end -- Returns the add key for the surround associated with a given character, if one exists. @@ -437,12 +444,9 @@ M.translate_add = function(user_add) if type(user_add) ~= "table" then return user_add end - -- If the input is given as a pair of strings, or pair of string lists, wrap it in a function + return function() - return { - functional.to_list(user_add[1]), - functional.to_list(user_add[2]), - } + return user_add end end diff --git a/tests/configuration_spec.lua b/tests/configuration_spec.lua index 31b256f..37d6d17 100644 --- a/tests/configuration_spec.lua +++ b/tests/configuration_spec.lua @@ -120,6 +120,26 @@ describe("configuration", function() check_lines({ "hey! hello world" }) end) + it("can use 'syntactic sugar' for add functions", function() + require("nvim-surround").buffer_setup({ + surrounds = { + ["("] = { + add = function() + return { "<<", ">>" } + end, + }, + }, + }) + + set_lines({ + "hello world", + }) + vim.cmd("normal yss(") + check_lines({ + "<>", + }) + end) + it("can disable surrounds", function() require("nvim-surround").buffer_setup({ surrounds = {