Skip to content

Commit 53a041f

Browse files
kylechuiphgz
authored andcommitted
feat(config): Automatically wrap add key in a table. (kylechui#342)
Inspired by the question in kylechui#341. Makes the UX more uniform between the "simple" version of the `add` key and more advanced callback functions.
1 parent a3e8033 commit 53a041f

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

lua/nvim-surround/annotations.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
---@alias position integer[] A 1-indexed position in the buffer
77
---@alias delimiter string[] The text representation of a delimiter
88
---@alias delimiter_pair delimiter[] A pair of delimiters
9-
---@alias add_func fun(char: string|nil): delimiter_pair|nil
9+
---@alias add_func fun(char: string|nil): delimiter_pair|string[]|nil
1010
---@alias find_func fun(char: string|nil): selection|nil
1111
---@alias delete_func fun(char: string|nil): selections|nil
1212
---@alias change_table { target: delete_func, replacement: add_func|nil }

lua/nvim-surround/config.lua

+14-10
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,18 @@ M.get_delimiters = function(char, line_mode)
362362
char = M.get_alias(char)
363363
-- Get the delimiters, using invalid_key_behavior if the add function is undefined for the character
364364
local delimiters = M.get_add(char)(char)
365-
-- Add new lines if the addition is done line-wise
366-
if delimiters and line_mode then
367-
local lhs = delimiters[1]
368-
local rhs = delimiters[2]
365+
if delimiters == nil then
366+
return nil
367+
end
368+
local lhs = type(delimiters[1]) == "string" and { delimiters[1] } or delimiters[1]
369+
local rhs = type(delimiters[2]) == "string" and { delimiters[2] } or delimiters[2]
370+
-- These casts are needed because LuaLS doesn't narrow types in ternaries properly
371+
-- https://github.com/LuaLS/lua-language-server/issues/2233
372+
---@cast lhs string[]
373+
---@cast rhs string[]
369374

375+
-- Add new lines if the addition is done line-wise
376+
if line_mode then
370377
-- Trim whitespace after the leading delimiter and before the trailing delimiter
371378
lhs[#lhs] = lhs[#lhs]:gsub("%s+$", "")
372379
-- Take into account the possibility that there is no rhs delimiter
@@ -376,7 +383,7 @@ M.get_delimiters = function(char, line_mode)
376383
table.insert(lhs, "")
377384
end
378385

379-
return delimiters
386+
return { lhs, rhs }
380387
end
381388

382389
-- Returns the add key for the surround associated with a given character, if one exists.
@@ -440,12 +447,9 @@ M.translate_add = function(user_add)
440447
if type(user_add) ~= "table" then
441448
return user_add
442449
end
443-
-- If the input is given as a pair of strings, or pair of string lists, wrap it in a function
450+
444451
return function()
445-
return {
446-
functional.to_list(user_add[1]),
447-
functional.to_list(user_add[2]),
448-
}
452+
return user_add
449453
end
450454
end
451455

tests/configuration_spec.lua

+20
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,26 @@ describe("configuration", function()
120120
check_lines({ "hey! hello world" })
121121
end)
122122

123+
it("can use 'syntactic sugar' for add functions", function()
124+
require("nvim-surround").buffer_setup({
125+
surrounds = {
126+
["("] = {
127+
add = function()
128+
return { "<<", ">>" }
129+
end,
130+
},
131+
},
132+
})
133+
134+
set_lines({
135+
"hello world",
136+
})
137+
vim.cmd("normal yss(")
138+
check_lines({
139+
"<<hello world>>",
140+
})
141+
end)
142+
123143
it("can disable surrounds", function()
124144
require("nvim-surround").buffer_setup({
125145
surrounds = {

0 commit comments

Comments
 (0)