Skip to content

Commit af584b9

Browse files
committed
fix: Normalize delimiters in change operations.
1 parent 87725ab commit af584b9

File tree

3 files changed

+68
-50
lines changed

3 files changed

+68
-50
lines changed

lua/nvim-surround/config.lua

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -355,18 +355,17 @@ end
355355
---@return delimiter_pair|nil @A pair of delimiters for the given input, or nil if not applicable.
356356
---@nodiscard
357357
M.get_delimiters = function(char, line_mode)
358+
local utils = require("nvim-surround.utils")
359+
358360
char = M.get_alias(char)
359361
-- Get the delimiters, using invalid_key_behavior if the add function is undefined for the character
360-
local delimiters = M.get_add(char)(char)
361-
if delimiters == nil then
362+
local raw_delimiters = M.get_add(char)(char)
363+
if raw_delimiters == nil then
362364
return nil
363365
end
364-
local lhs = type(delimiters[1]) == "string" and { delimiters[1] } or delimiters[1]
365-
local rhs = type(delimiters[2]) == "string" and { delimiters[2] } or delimiters[2]
366-
-- These casts are needed because LuaLS doesn't narrow types in ternaries properly
367-
-- https://github.com/LuaLS/lua-language-server/issues/2233
368-
---@cast lhs string[]
369-
---@cast rhs string[]
366+
local delimiters = utils.normalize_delimiters(raw_delimiters)
367+
local lhs = delimiters[1]
368+
local rhs = delimiters[2]
370369

371370
-- Add new lines if the addition is done line-wise
372371
if line_mode then

lua/nvim-surround/init.lua

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -222,45 +222,48 @@ M.change_surround = function(args)
222222
buffer.set_curpos(args.curpos)
223223
-- Get the selections to change, as well as the delimiters to replace those selections
224224
local selections = utils.get_nearest_selections(args.del_char, "change")
225-
local delimiters = args.add_delimiters()
226-
if selections and delimiters then
227-
-- Avoid adding any, and remove any existing whitespace after the
228-
-- opening delimiter if only whitespace exists between it and the end
229-
-- of the line. Avoid adding or removing leading whitespace before the
230-
-- closing delimiter if only whitespace exists between it and the
231-
-- beginning of the line.
232-
233-
local space_begin, space_end = buffer.get_line(selections.left.last_pos[1]):find("%s*$")
234-
if space_begin - 1 <= selections.left.last_pos[2] then -- Whitespace is adjacent to opening delimiter
235-
-- Trim trailing whitespace from opening delimiter
236-
delimiters[1][#delimiters[1]] = delimiters[1][#delimiters[1]]:gsub("%s+$", "")
237-
-- Grow selection end to include trailing whitespace, so it gets removed
238-
selections.left.last_pos[2] = space_end
239-
end
225+
local raw_delimiters = args.add_delimiters()
226+
if not (selections and raw_delimiters) then
227+
cache.set_callback("v:lua.require'nvim-surround'.change_callback")
228+
return
229+
end
230+
local delimiters = utils.normalize_delimiters(raw_delimiters)
231+
-- Avoid adding any, and remove any existing whitespace after the
232+
-- opening delimiter if only whitespace exists between it and the end
233+
-- of the line. Avoid adding or removing leading whitespace before the
234+
-- closing delimiter if only whitespace exists between it and the
235+
-- beginning of the line.
236+
237+
local space_begin, space_end = buffer.get_line(selections.left.last_pos[1]):find("%s*$")
238+
if space_begin - 1 <= selections.left.last_pos[2] then -- Whitespace is adjacent to opening delimiter
239+
-- Trim trailing whitespace from opening delimiter
240+
delimiters[1][#delimiters[1]] = delimiters[1][#delimiters[1]]:gsub("%s+$", "")
241+
-- Grow selection end to include trailing whitespace, so it gets removed
242+
selections.left.last_pos[2] = space_end
243+
end
240244

241-
space_begin, space_end = buffer.get_line(selections.right.first_pos[1]):find("^%s*")
242-
if space_end + 1 >= selections.right.first_pos[2] then -- Whitespace is adjacent to closing delimiter
243-
-- Trim leading whitespace from closing delimiter
244-
delimiters[2][1] = delimiters[2][1]:gsub("^%s+", "")
245-
-- Shrink selection beginning to exclude leading whitespace, so it remains unchanged
246-
selections.right.first_pos[2] = space_end + 1
247-
end
245+
space_begin, space_end = buffer.get_line(selections.right.first_pos[1]):find("^%s*")
246+
if space_end + 1 >= selections.right.first_pos[2] then -- Whitespace is adjacent to closing delimiter
247+
-- Trim leading whitespace from closing delimiter
248+
delimiters[2][1] = delimiters[2][1]:gsub("^%s+", "")
249+
-- Shrink selection beginning to exclude leading whitespace, so it remains unchanged
250+
selections.right.first_pos[2] = space_end + 1
251+
end
248252

249-
local sticky_pos = buffer.with_extmark(args.curpos, function()
250-
buffer.change_selection(selections.right, delimiters[2])
251-
buffer.change_selection(selections.left, delimiters[1])
252-
end)
253-
buffer.restore_curpos({
254-
first_pos = selections.left.first_pos,
255-
sticky_pos = sticky_pos,
256-
old_pos = args.curpos,
257-
})
253+
local sticky_pos = buffer.with_extmark(args.curpos, function()
254+
buffer.change_selection(selections.right, delimiters[2])
255+
buffer.change_selection(selections.left, delimiters[1])
256+
end)
257+
buffer.restore_curpos({
258+
first_pos = selections.left.first_pos,
259+
sticky_pos = sticky_pos,
260+
old_pos = args.curpos,
261+
})
258262

259-
if args.line_mode then
260-
local first_line = selections.left.first_pos[1]
261-
local last_line = selections.right.last_pos[1]
262-
config.get_opts().indent_lines(first_line, last_line + #delimiters[1] + #delimiters[2] - 2)
263-
end
263+
if args.line_mode then
264+
local first_line = selections.left.first_pos[1]
265+
local last_line = selections.right.last_pos[1]
266+
config.get_opts().indent_lines(first_line, last_line + #delimiters[1] + #delimiters[2] - 2)
264267
end
265268

266269
cache.set_callback("v:lua.require'nvim-surround'.change_callback")
@@ -365,7 +368,9 @@ M.change_callback = function()
365368
return
366369
end
367370

371+
-- To handle number prefixing properly, we just run the replacement algorithm multiple times
368372
for _ = 1, cache.change.count do
373+
-- If at any point we are unable to find a surrounding pair to change, early exit
369374
local selections = utils.get_nearest_selections(del_char, "change")
370375
if not selections then
371376
return
@@ -382,6 +387,8 @@ M.change_callback = function()
382387
end
383388

384389
-- Get the new surrounding delimiter pair, prioritizing any delimiters in the cache
390+
-- NB: This must occur between drawing the highlights and clearing them, so the selections are properly
391+
-- highlighted if the user is providing (blocking) input
385392
local delimiters = cache.change.add_delimiters and cache.change.add_delimiters()
386393
if not delimiters then
387394
if change and change.replacement then
@@ -398,20 +405,19 @@ M.change_callback = function()
398405
return
399406
end
400407

408+
local add_delimiters = function()
409+
return delimiters
410+
end
401411
-- Set the cache
402412
cache.change = {
403413
del_char = del_char,
404-
add_delimiters = function()
405-
return delimiters
406-
end,
414+
add_delimiters = add_delimiters,
407415
line_mode = cache.change.line_mode,
408416
count = cache.change.count,
409417
}
410418
M.change_surround({
411419
del_char = del_char,
412-
add_delimiters = function()
413-
return delimiters
414-
end,
420+
add_delimiters = add_delimiters,
415421
line_mode = cache.change.line_mode,
416422
count = cache.change.count,
417423
curpos = buffer.get_curpos(),

lua/nvim-surround/utils.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ M.repeat_delimiters = function(delimiters, n)
2323
return acc
2424
end
2525

26+
-- Normalizes a pair of delimiters to use a string[] for both the left and right delimiters
27+
---@param raw_delimiters (string|string[])[] The delimiters to be repeated.
28+
---@return delimiter_pair @The normalized delimiters.
29+
---@nodiscard
30+
M.normalize_delimiters = function(raw_delimiters)
31+
local lhs = raw_delimiters[1]
32+
local rhs = raw_delimiters[2]
33+
return {
34+
type(lhs) == "string" and { lhs } or lhs,
35+
type(rhs) == "string" and { rhs } or rhs,
36+
}
37+
end
38+
2639
-- Gets the nearest two selections for the left and right surrounding pair.
2740
---@param char string|nil A character representing what kind of surrounding pair is to be selected.
2841
---@param action "delete"|"change" A string representing what action is being performed.

0 commit comments

Comments
 (0)