Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lua/nvim-surround/cache.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ local M = {}

---@type { delimiters: string[][]|nil, line_mode: boolean, count: integer }
M.normal = {}
---@type { char: string }
---@type { char: string, count: integer }
M.delete = {}
---@type { del_char: string, add_delimiters: add_func, line_mode: boolean }
---@type { del_char: string, add_delimiters: add_func, line_mode: boolean, count: integer }
M.change = {}

-- Sets the callback function for dot-repeating.
---@param func_name string A string representing the callback function's name.
M.set_callback = function(func_name)
vim.go.operatorfunc = "v:lua.require'nvim-surround.utils'.NOOP"
vim.cmd.normal({ "g@l", bang = true })
vim.cmd.normal({ [1] = "g@l", bang = true })
vim.go.operatorfunc = func_name
end

Expand Down
15 changes: 7 additions & 8 deletions lua/nvim-surround/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,17 @@ end
---@return delimiter_pair|nil @A pair of delimiters for the given input, or nil if not applicable.
---@nodiscard
M.get_delimiters = function(char, line_mode)
local utils = require("nvim-surround.utils")

char = M.get_alias(char)
-- Get the delimiters, using invalid_key_behavior if the add function is undefined for the character
local delimiters = M.get_add(char)(char)
if delimiters == nil then
local raw_delimiters = M.get_add(char)(char)
if raw_delimiters == nil then
return nil
end
local lhs = type(delimiters[1]) == "string" and { delimiters[1] } or delimiters[1]
local rhs = type(delimiters[2]) == "string" and { delimiters[2] } or delimiters[2]
-- These casts are needed because LuaLS doesn't narrow types in ternaries properly
-- https://github.com/LuaLS/lua-language-server/issues/2233
---@cast lhs string[]
---@cast rhs string[]
local delimiters = utils.normalize_delimiters(raw_delimiters)
local lhs = delimiters[1]
local rhs = delimiters[2]

-- Add new lines if the addition is done line-wise
if line_mode then
Expand Down
144 changes: 81 additions & 63 deletions lua/nvim-surround/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ M.delete_surround = function(args)
-- Call the operatorfunc if it has not been called yet
if not args then
-- Clear the delete cache (since it was user-called)
cache.delete = {}
cache.delete = { count = vim.v.count1 }

vim.go.operatorfunc = "v:lua.require'nvim-surround'.delete_callback"
return "g@l"
Expand Down Expand Up @@ -213,7 +213,7 @@ M.change_surround = function(args)
-- Call the operatorfunc if it has not been called yet
if not args.del_char or not args.add_delimiters then
-- Clear the change cache (since it was user-called)
cache.change = { line_mode = args.line_mode }
cache.change = { line_mode = args.line_mode, count = vim.v.count1 }

vim.go.operatorfunc = "v:lua.require'nvim-surround'.change_callback"
return "g@l"
Expand All @@ -222,45 +222,48 @@ M.change_surround = function(args)
buffer.set_curpos(args.curpos)
-- Get the selections to change, as well as the delimiters to replace those selections
local selections = utils.get_nearest_selections(args.del_char, "change")
local delimiters = args.add_delimiters()
if selections and delimiters then
-- Avoid adding any, and remove any existing whitespace after the
-- opening delimiter if only whitespace exists between it and the end
-- of the line. Avoid adding or removing leading whitespace before the
-- closing delimiter if only whitespace exists between it and the
-- beginning of the line.

local space_begin, space_end = buffer.get_line(selections.left.last_pos[1]):find("%s*$")
if space_begin - 1 <= selections.left.last_pos[2] then -- Whitespace is adjacent to opening delimiter
-- Trim trailing whitespace from opening delimiter
delimiters[1][#delimiters[1]] = delimiters[1][#delimiters[1]]:gsub("%s+$", "")
-- Grow selection end to include trailing whitespace, so it gets removed
selections.left.last_pos[2] = space_end
end
local raw_delimiters = args.add_delimiters()
if not (selections and raw_delimiters) then
cache.set_callback("v:lua.require'nvim-surround'.change_callback")
return
end
local delimiters = utils.normalize_delimiters(raw_delimiters)
-- Avoid adding any, and remove any existing whitespace after the
-- opening delimiter if only whitespace exists between it and the end
-- of the line. Avoid adding or removing leading whitespace before the
-- closing delimiter if only whitespace exists between it and the
-- beginning of the line.

local space_begin, space_end = buffer.get_line(selections.left.last_pos[1]):find("%s*$")
if space_begin - 1 <= selections.left.last_pos[2] then -- Whitespace is adjacent to opening delimiter
-- Trim trailing whitespace from opening delimiter
delimiters[1][#delimiters[1]] = delimiters[1][#delimiters[1]]:gsub("%s+$", "")
-- Grow selection end to include trailing whitespace, so it gets removed
selections.left.last_pos[2] = space_end
end

space_begin, space_end = buffer.get_line(selections.right.first_pos[1]):find("^%s*")
if space_end + 1 >= selections.right.first_pos[2] then -- Whitespace is adjacent to closing delimiter
-- Trim leading whitespace from closing delimiter
delimiters[2][1] = delimiters[2][1]:gsub("^%s+", "")
-- Shrink selection beginning to exclude leading whitespace, so it remains unchanged
selections.right.first_pos[2] = space_end + 1
end
space_begin, space_end = buffer.get_line(selections.right.first_pos[1]):find("^%s*")
if space_end + 1 >= selections.right.first_pos[2] then -- Whitespace is adjacent to closing delimiter
-- Trim leading whitespace from closing delimiter
delimiters[2][1] = delimiters[2][1]:gsub("^%s+", "")
-- Shrink selection beginning to exclude leading whitespace, so it remains unchanged
selections.right.first_pos[2] = space_end + 1
end

local sticky_pos = buffer.with_extmark(args.curpos, function()
buffer.change_selection(selections.right, delimiters[2])
buffer.change_selection(selections.left, delimiters[1])
end)
buffer.restore_curpos({
first_pos = selections.left.first_pos,
sticky_pos = sticky_pos,
old_pos = args.curpos,
})
local sticky_pos = buffer.with_extmark(args.curpos, function()
buffer.change_selection(selections.right, delimiters[2])
buffer.change_selection(selections.left, delimiters[1])
end)
buffer.restore_curpos({
first_pos = selections.left.first_pos,
sticky_pos = sticky_pos,
old_pos = args.curpos,
})

if args.line_mode then
local first_line = selections.left.first_pos[1]
local last_line = selections.right.last_pos[1]
config.get_opts().indent_lines(first_line, last_line + #delimiters[1] + #delimiters[2] - 2)
end
if args.line_mode then
local first_line = selections.left.first_pos[1]
local last_line = selections.right.last_pos[1]
config.get_opts().indent_lines(first_line, last_line + #delimiters[1] + #delimiters[2] - 2)
end

cache.set_callback("v:lua.require'nvim-surround'.change_callback")
Expand Down Expand Up @@ -338,18 +341,18 @@ M.delete_callback = function()
local buffer = require("nvim-surround.buffer")
local cache = require("nvim-surround.cache")
local input = require("nvim-surround.input")
-- Save the current position of the cursor
local curpos = buffer.get_curpos()
-- Get a character input if not cached
cache.delete.char = cache.delete.char or input.get_char()
if not cache.delete.char then
return
end

M.delete_surround({
del_char = cache.delete.char,
curpos = curpos,
})
for _ = 1, cache.delete.count do
M.delete_surround({
del_char = cache.delete.char,
curpos = buffer.get_curpos(),
})
end
end

M.change_callback = function()
Expand All @@ -358,13 +361,18 @@ M.change_callback = function()
local cache = require("nvim-surround.cache")
local input = require("nvim-surround.input")
local utils = require("nvim-surround.utils")
-- Save the current position of the cursor
local curpos = buffer.get_curpos()
if not cache.change.del_char or not cache.change.add_delimiters then
local del_char = config.get_alias(input.get_char())
local change = config.get_change(del_char)

local del_char = cache.change.del_char or config.get_alias(input.get_char())
local change = config.get_change(del_char)
if not (del_char and change) then
return
end

-- To handle number prefixing properly, we just run the replacement algorithm multiple times
for _ = 1, cache.change.count do
-- If at any point we are unable to find a surrounding pair to change, early exit
local selections = utils.get_nearest_selections(del_char, "change")
if not (del_char and change and selections) then
if not selections then
return
end

Expand All @@ -378,13 +386,17 @@ M.change_callback = function()
end
end

-- Get the new surrounding pair, querying the user for more input if no replacement is provided
local ins_char, delimiters
if change and change.replacement then
delimiters = change.replacement()
else
ins_char = input.get_char()
delimiters = config.get_delimiters(ins_char, cache.change.line_mode)
-- Get the new surrounding delimiter pair, prioritizing any delimiters in the cache
-- NB: This must occur between drawing the highlights and clearing them, so the selections are properly
-- highlighted if the user is providing (blocking) input
local delimiters = cache.change.add_delimiters and cache.change.add_delimiters()
if not delimiters then
if change and change.replacement then
delimiters = delimiters or change.replacement()
else
local ins_char = input.get_char()
delimiters = delimiters or config.get_delimiters(ins_char, cache.change.line_mode)
end
end

-- Clear the highlights after getting the replacement surround
Expand All @@ -393,18 +405,24 @@ M.change_callback = function()
return
end

local add_delimiters = function()
return delimiters
end
-- Set the cache
cache.change = {
del_char = del_char,
add_delimiters = function()
return delimiters
end,
add_delimiters = add_delimiters,
line_mode = cache.change.line_mode,
count = cache.change.count,
}
M.change_surround({
del_char = del_char,
add_delimiters = add_delimiters,
line_mode = cache.change.line_mode,
count = cache.change.count,
curpos = buffer.get_curpos(),
})
end
local args = vim.deepcopy(cache.change)
args.curpos = curpos
M.change_surround(args) ---@diagnostic disable-line: param-type-mismatch
end

return M
13 changes: 13 additions & 0 deletions lua/nvim-surround/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ M.repeat_delimiters = function(delimiters, n)
return acc
end

-- Normalizes a pair of delimiters to use a string[] for both the left and right delimiters
---@param raw_delimiters (string|string[])[] The delimiters to be repeated.
---@return delimiter_pair @The normalized delimiters.
---@nodiscard
M.normalize_delimiters = function(raw_delimiters)
local lhs = raw_delimiters[1]
local rhs = raw_delimiters[2]
return {
type(lhs) == "string" and { lhs } or lhs,
type(rhs) == "string" and { rhs } or rhs,
}
end

-- Gets the nearest two selections for the left and right surrounding pair.
---@param char string|nil A character representing what kind of surrounding pair is to be selected.
---@param action "delete"|"change" A string representing what action is being performed.
Expand Down
38 changes: 38 additions & 0 deletions tests/basics_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -866,4 +866,42 @@ describe("nvim-surround", function()
"a sli<<<ghtly longer l>>>ine",
})
end)

it("can handle number prefixing for deleting surrounds", function()
set_lines({ "some {{{{more placeholder}}}} text" })
set_curpos({ 1, 6 })
vim.cmd("normal 2dsB")
check_lines({ "some {{more placeholder}} text" })
vim.cmd("normal .")
check_lines({ "some more placeholder text" })

set_lines({ "((foo) bar (baz))" })
set_curpos({ 1, 9 })
vim.cmd("normal 2dsb")
check_lines({ "foo bar (baz)" })

set_lines({ "some ((more placeholder)) text" })
set_curpos({ 1, 6 })
vim.cmd("normal 3dsb")
check_lines({ "some more placeholder text" })
end)

it("can handle number prefixing for changing surrounds", function()
set_lines({ "some {{{{more placeholder}}}} text" })
set_curpos({ 1, 11 })
vim.cmd("normal 2csBa")
check_lines({ "some {{<<more placeholder>>}} text" })
vim.cmd("normal .")
check_lines({ "some <<<<more placeholder>>>> text" })

set_lines({ "((foo) bar (baz))" })
set_curpos({ 1, 9 })
vim.cmd("normal 2csbB")
check_lines({ "{{foo} bar (baz)}" })

set_lines({ "some ((more placeholder)) text" })
set_curpos({ 1, 6 })
vim.cmd("normal 3csbr")
check_lines({ "some [[more placeholder]] text" })
end)
end)
29 changes: 29 additions & 0 deletions tests/configuration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,33 @@ describe("configuration", function()
assert.are.same(get_curpos(), { 1, 1 })
check_lines({ "print('foo')" })
end)

it("will handle number prefixing as if the user used dot-repeat", function()
require("nvim-surround").setup({ move_cursor = "sticky" })
set_lines({ "foo bar baz" })
set_curpos({ 1, 5 })
vim.cmd("normal 3ysiwb")
check_lines({ "foo (((bar))) baz" })
check_curpos({ 1, 8 })
vim.cmd("normal 2ySSa")
check_lines({
"<",
"<",
"foo (((bar))) baz",
">",
">",
})

set_lines({ "((foo) bar (baz))" })
set_curpos({ 1, 9 })
vim.cmd("normal 2dsb")
check_lines({ "(foo) bar baz" })
check_curpos({ 1, 8 })

set_lines({ "((foo) bar (baz))" })
set_curpos({ 1, 9 })
vim.cmd("normal 2csbr")
check_lines({ "[(foo) bar [baz]]" })
check_curpos({ 1, 9 })
end)
end)
Loading