Skip to content

Commit

Permalink
feat: Modify user_input to return count as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
kylechui committed Nov 26, 2024
1 parent 570298e commit 324cfd8
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 30 deletions.
4 changes: 2 additions & 2 deletions lua/nvim-surround/cache.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ local M = {}

---@type { delimiters: string[][]|nil, line_mode: boolean }
M.normal = {}
---@type { char: string }
M.delete = {}
---@type { char: string, count: integer }|nil
M.delete = nil
---@type { del_char: string, add_delimiters: add_func, line_mode: boolean }
M.change = {}

Expand Down
52 changes: 37 additions & 15 deletions lua/nvim-surround/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ M.insert_surround = function(args)
local config = require("nvim-surround.config")
local buffer = require("nvim-surround.buffer")
local input = require("nvim-surround.input")
local char = input.get_char()
local curpos = buffer.get_curpos()
local delimiters = config.get_delimiters(char, args.line_mode)
local user_input = input.get_char()
if not user_input then
return
end
-- TODO: Handle repeating!
local delimiters = config.get_delimiters(user_input.char, args.line_mode)
if not delimiters then
return
end

local curpos = buffer.get_curpos()
buffer.insert_text(curpos, delimiters[2])
buffer.insert_text(curpos, delimiters[1])
buffer.set_curpos({ curpos[1] + #delimiters[1] - 1, curpos[2] + #delimiters[1][#delimiters[1]] })
Expand Down Expand Up @@ -86,12 +90,15 @@ M.visual_surround = function(args)
local config = require("nvim-surround.config")
local buffer = require("nvim-surround.buffer")
local input = require("nvim-surround.input")
local ins_char = input.get_char()
local user_input = input.get_char()
if user_input == nil then
return
end

if vim.fn.visualmode() == "V" then
args.line_mode = true
end
local delimiters = config.get_delimiters(ins_char, args.line_mode)
local delimiters = config.get_delimiters(user_input.char, args.line_mode)
local first_pos, last_pos = buffer.get_mark("<"), buffer.get_mark(">")
if not delimiters or not first_pos or not last_pos then
return
Expand Down Expand Up @@ -163,7 +170,7 @@ M.delete_surround = function(args)
-- Call the operatorfunc if it has not been called yet
if not args then
-- Clear the delete cache (since it was user-called)
cache.delete = {}
cache.delete = nil

vim.go.operatorfunc = "v:lua.require'nvim-surround'.delete_callback"
return "g@l"
Expand Down Expand Up @@ -301,9 +308,14 @@ M.normal_callback = function(mode)
end
-- Get a character input and the delimiters (if not cached)
if not cache.normal.delimiters then
local char = input.get_char()
local user_input = input.get_char()
if user_input == nil then
M.pending_surround = false
buffer.clear_highlights()
return
end
-- Get the delimiter pair based on the input character
cache.normal.delimiters = config.get_delimiters(char, cache.normal.line_mode)
cache.normal.delimiters = config.get_delimiters(user_input.char, cache.normal.line_mode)
if not cache.normal.delimiters then
M.pending_surround = false
buffer.clear_highlights()
Expand All @@ -328,9 +340,11 @@ M.delete_callback = function()
-- Save the current position of the cursor
local curpos = buffer.get_curpos()
-- Get a character input if not cached
cache.delete.char = cache.delete.char or input.get_char()
if not cache.delete.char then
return
if cache.delete == nil then
cache.delete = input.get_char()
if cache.delete == nil then
return
end
end

M.delete_surround({
Expand All @@ -348,7 +362,12 @@ M.change_callback = function()
-- Save the current position of the cursor
local curpos = buffer.get_curpos()
if not cache.change.del_char or not cache.change.add_delimiters then
local del_char = config.get_alias(input.get_char())
local user_input = input.get_char()
if user_input == nil then
return
end

local del_char = config.get_alias(user_input.char)
local change = config.get_change(del_char)
local selections = utils.get_nearest_selections(del_char, "change")
if not (del_char and change and selections) then
Expand All @@ -366,12 +385,15 @@ M.change_callback = function()
end

-- Get the new surrounding pair, querying the user for more input if no replacement is provided
local ins_char, delimiters
local delimiters
if change and change.replacement then
delimiters = change.replacement()
else
ins_char = input.get_char()
delimiters = config.get_delimiters(ins_char, cache.change.line_mode)
local user_input = input.get_char()
if user_input == nil then
return
end
delimiters = config.get_delimiters(user_input.char, cache.change.line_mode)
end

-- Clear the highlights after getting the replacement surround
Expand Down
32 changes: 25 additions & 7 deletions lua/nvim-surround/input.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,33 @@ M.replace_termcodes = function(char)
end

-- Gets a character input from the user.
---@return string|nil @The input character, or nil if an escape character is pressed.
---@return {char: string, count: integer}|nil @The input character, or nil if an escape character is pressed.
---@nodiscard
M.get_char = function()
local ok, char = pcall(vim.fn.getcharstr)
-- Return nil if input is cancelled (e.g. <C-c> or <Esc>)
if not ok or char == "\27" then
return nil
end
return M.replace_termcodes(char)
local has_count = false
local count = 0
local char = nil

repeat
local ok, input_char = pcall(vim.fn.getcharstr)
-- Return nil if input is cancelled (e.g. <C-c> or <Esc>)
if not ok or input_char == "\27" then
return nil
end

local digit = tonumber(input_char)
if digit ~= nil then
has_count = true
count = 10 * count + digit
else
char = M.replace_termcodes(input_char)
end
until char ~= nil

return {
count = has_count and count or 1,
char = char,
}
end

-- Gets a string input from the user.
Expand Down
16 changes: 16 additions & 0 deletions lua/nvim-surround/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@ local M = {}
-- Do nothing.
M.NOOP = function() end

-- Repeats a delimiter pair n times.
---@param delimiters delimiter_pair The delimiters to be repeated.
---@param n integer The number of times to repeat the delimiters.
---@return delimiter_pair @The repeated delimiters.
---@nodiscard
M.repeat_delimiters = function(delimiters, n)
local acc = { { "" }, { "" } }
for _ = 1, n do
acc[1][#acc[1]] = acc[1][#acc[1]] .. delimiters[1][1]
vim.list_extend(acc[1], delimiters[1], 2)
acc[2][#acc[2]] = acc[2][#acc[2]] .. delimiters[2][1]
vim.list_extend(acc[2], delimiters[2], 2)
end
return acc
end

-- Gets the nearest two selections for the left and right surrounding pair.
---@param char string|nil A character representing what kind of surrounding pair is to be selected.
---@param action "delete"|"change" A string representing what action is being performed.
Expand Down
12 changes: 6 additions & 6 deletions tests/configuration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ describe("configuration", function()
it("can define own add mappings", function()
require("nvim-surround").buffer_setup({
surrounds = {
["1"] = { add = { "1", "1" } },
["2"] = { add = { "2", { "2" } } },
["3"] = { add = { { "3" }, "3" } },
["q"] = { add = { "1", "1" } },
["w"] = { add = { "2", { "2" } } },
["e"] = { add = { { "3" }, "3" } },
["f"] = { add = { { "int main() {", " " }, { "", "}" } } },
},
})
Expand All @@ -44,11 +44,11 @@ describe("configuration", function()
"interesting stuff",
})
set_curpos({ 1, 1 })
vim.cmd("normal yss1")
vim.cmd("normal yssq")
set_curpos({ 2, 1 })
vim.cmd("normal yss2")
vim.cmd("normal yssw")
set_curpos({ 3, 1 })
vim.cmd("normal yss3")
vim.cmd("normal ysse")
set_curpos({ 4, 1 })
vim.cmd("normal yssf")
check_lines({
Expand Down

0 comments on commit 324cfd8

Please sign in to comment.