Skip to content

Commit 324cfd8

Browse files
committed
feat: Modify user_input to return count as well.
1 parent 570298e commit 324cfd8

File tree

5 files changed

+86
-30
lines changed

5 files changed

+86
-30
lines changed

lua/nvim-surround/cache.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ local M = {}
44

55
---@type { delimiters: string[][]|nil, line_mode: boolean }
66
M.normal = {}
7-
---@type { char: string }
8-
M.delete = {}
7+
---@type { char: string, count: integer }|nil
8+
M.delete = nil
99
---@type { del_char: string, add_delimiters: add_func, line_mode: boolean }
1010
M.change = {}
1111

lua/nvim-surround/init.lua

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ M.insert_surround = function(args)
1818
local config = require("nvim-surround.config")
1919
local buffer = require("nvim-surround.buffer")
2020
local input = require("nvim-surround.input")
21-
local char = input.get_char()
22-
local curpos = buffer.get_curpos()
23-
local delimiters = config.get_delimiters(char, args.line_mode)
21+
local user_input = input.get_char()
22+
if not user_input then
23+
return
24+
end
25+
-- TODO: Handle repeating!
26+
local delimiters = config.get_delimiters(user_input.char, args.line_mode)
2427
if not delimiters then
2528
return
2629
end
2730

31+
local curpos = buffer.get_curpos()
2832
buffer.insert_text(curpos, delimiters[2])
2933
buffer.insert_text(curpos, delimiters[1])
3034
buffer.set_curpos({ curpos[1] + #delimiters[1] - 1, curpos[2] + #delimiters[1][#delimiters[1]] })
@@ -86,12 +90,15 @@ M.visual_surround = function(args)
8690
local config = require("nvim-surround.config")
8791
local buffer = require("nvim-surround.buffer")
8892
local input = require("nvim-surround.input")
89-
local ins_char = input.get_char()
93+
local user_input = input.get_char()
94+
if user_input == nil then
95+
return
96+
end
9097

9198
if vim.fn.visualmode() == "V" then
9299
args.line_mode = true
93100
end
94-
local delimiters = config.get_delimiters(ins_char, args.line_mode)
101+
local delimiters = config.get_delimiters(user_input.char, args.line_mode)
95102
local first_pos, last_pos = buffer.get_mark("<"), buffer.get_mark(">")
96103
if not delimiters or not first_pos or not last_pos then
97104
return
@@ -163,7 +170,7 @@ M.delete_surround = function(args)
163170
-- Call the operatorfunc if it has not been called yet
164171
if not args then
165172
-- Clear the delete cache (since it was user-called)
166-
cache.delete = {}
173+
cache.delete = nil
167174

168175
vim.go.operatorfunc = "v:lua.require'nvim-surround'.delete_callback"
169176
return "g@l"
@@ -301,9 +308,14 @@ M.normal_callback = function(mode)
301308
end
302309
-- Get a character input and the delimiters (if not cached)
303310
if not cache.normal.delimiters then
304-
local char = input.get_char()
311+
local user_input = input.get_char()
312+
if user_input == nil then
313+
M.pending_surround = false
314+
buffer.clear_highlights()
315+
return
316+
end
305317
-- Get the delimiter pair based on the input character
306-
cache.normal.delimiters = config.get_delimiters(char, cache.normal.line_mode)
318+
cache.normal.delimiters = config.get_delimiters(user_input.char, cache.normal.line_mode)
307319
if not cache.normal.delimiters then
308320
M.pending_surround = false
309321
buffer.clear_highlights()
@@ -328,9 +340,11 @@ M.delete_callback = function()
328340
-- Save the current position of the cursor
329341
local curpos = buffer.get_curpos()
330342
-- Get a character input if not cached
331-
cache.delete.char = cache.delete.char or input.get_char()
332-
if not cache.delete.char then
333-
return
343+
if cache.delete == nil then
344+
cache.delete = input.get_char()
345+
if cache.delete == nil then
346+
return
347+
end
334348
end
335349

336350
M.delete_surround({
@@ -348,7 +362,12 @@ M.change_callback = function()
348362
-- Save the current position of the cursor
349363
local curpos = buffer.get_curpos()
350364
if not cache.change.del_char or not cache.change.add_delimiters then
351-
local del_char = config.get_alias(input.get_char())
365+
local user_input = input.get_char()
366+
if user_input == nil then
367+
return
368+
end
369+
370+
local del_char = config.get_alias(user_input.char)
352371
local change = config.get_change(del_char)
353372
local selections = utils.get_nearest_selections(del_char, "change")
354373
if not (del_char and change and selections) then
@@ -366,12 +385,15 @@ M.change_callback = function()
366385
end
367386

368387
-- Get the new surrounding pair, querying the user for more input if no replacement is provided
369-
local ins_char, delimiters
388+
local delimiters
370389
if change and change.replacement then
371390
delimiters = change.replacement()
372391
else
373-
ins_char = input.get_char()
374-
delimiters = config.get_delimiters(ins_char, cache.change.line_mode)
392+
local user_input = input.get_char()
393+
if user_input == nil then
394+
return
395+
end
396+
delimiters = config.get_delimiters(user_input.char, cache.change.line_mode)
375397
end
376398

377399
-- Clear the highlights after getting the replacement surround

lua/nvim-surround/input.lua

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,33 @@ M.replace_termcodes = function(char)
1414
end
1515

1616
-- Gets a character input from the user.
17-
---@return string|nil @The input character, or nil if an escape character is pressed.
17+
---@return {char: string, count: integer}|nil @The input character, or nil if an escape character is pressed.
1818
---@nodiscard
1919
M.get_char = function()
20-
local ok, char = pcall(vim.fn.getcharstr)
21-
-- Return nil if input is cancelled (e.g. <C-c> or <Esc>)
22-
if not ok or char == "\27" then
23-
return nil
24-
end
25-
return M.replace_termcodes(char)
20+
local has_count = false
21+
local count = 0
22+
local char = nil
23+
24+
repeat
25+
local ok, input_char = pcall(vim.fn.getcharstr)
26+
-- Return nil if input is cancelled (e.g. <C-c> or <Esc>)
27+
if not ok or input_char == "\27" then
28+
return nil
29+
end
30+
31+
local digit = tonumber(input_char)
32+
if digit ~= nil then
33+
has_count = true
34+
count = 10 * count + digit
35+
else
36+
char = M.replace_termcodes(input_char)
37+
end
38+
until char ~= nil
39+
40+
return {
41+
count = has_count and count or 1,
42+
char = char,
43+
}
2644
end
2745

2846
-- Gets a string input from the user.

lua/nvim-surround/utils.lua

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,22 @@ local M = {}
77
-- Do nothing.
88
M.NOOP = function() end
99

10+
-- Repeats a delimiter pair n times.
11+
---@param delimiters delimiter_pair The delimiters to be repeated.
12+
---@param n integer The number of times to repeat the delimiters.
13+
---@return delimiter_pair @The repeated delimiters.
14+
---@nodiscard
15+
M.repeat_delimiters = function(delimiters, n)
16+
local acc = { { "" }, { "" } }
17+
for _ = 1, n do
18+
acc[1][#acc[1]] = acc[1][#acc[1]] .. delimiters[1][1]
19+
vim.list_extend(acc[1], delimiters[1], 2)
20+
acc[2][#acc[2]] = acc[2][#acc[2]] .. delimiters[2][1]
21+
vim.list_extend(acc[2], delimiters[2], 2)
22+
end
23+
return acc
24+
end
25+
1026
-- Gets the nearest two selections for the left and right surrounding pair.
1127
---@param char string|nil A character representing what kind of surrounding pair is to be selected.
1228
---@param action "delete"|"change" A string representing what action is being performed.

tests/configuration_spec.lua

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ describe("configuration", function()
3030
it("can define own add mappings", function()
3131
require("nvim-surround").buffer_setup({
3232
surrounds = {
33-
["1"] = { add = { "1", "1" } },
34-
["2"] = { add = { "2", { "2" } } },
35-
["3"] = { add = { { "3" }, "3" } },
33+
["q"] = { add = { "1", "1" } },
34+
["w"] = { add = { "2", { "2" } } },
35+
["e"] = { add = { { "3" }, "3" } },
3636
["f"] = { add = { { "int main() {", " " }, { "", "}" } } },
3737
},
3838
})
@@ -44,11 +44,11 @@ describe("configuration", function()
4444
"interesting stuff",
4545
})
4646
set_curpos({ 1, 1 })
47-
vim.cmd("normal yss1")
47+
vim.cmd("normal yssq")
4848
set_curpos({ 2, 1 })
49-
vim.cmd("normal yss2")
49+
vim.cmd("normal yssw")
5050
set_curpos({ 3, 1 })
51-
vim.cmd("normal yss3")
51+
vim.cmd("normal ysse")
5252
set_curpos({ 4, 1 })
5353
vim.cmd("normal yssf")
5454
check_lines({

0 commit comments

Comments
 (0)