Skip to content

Commit 896b0a6

Browse files
committed
fix: Use native vim.treesitter module when available.
1 parent 8dd9150 commit 896b0a6

File tree

2 files changed

+62
-69
lines changed

2 files changed

+62
-69
lines changed

lua/nvim-surround/queries.lua

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,22 @@
1-
local utils = require("nvim-surround.utils")
2-
local ts_query = require("nvim-treesitter.query")
3-
local ts_utils = require("nvim-treesitter.ts_utils")
4-
local ts_parsers = require("nvim-treesitter.parsers")
5-
61
local M = {}
72

83
-- Retrieves the node that corresponds exactly to a given selection.
94
---@param selection selection The given selection.
10-
---@return _ @The corresponding node.
5+
---@return TSNode|nil @The corresponding node.
116
---@nodiscard
127
M.get_node = function(selection)
13-
-- Convert the selection into a list
14-
local range = {
15-
selection.first_pos[1],
16-
selection.first_pos[2],
17-
selection.last_pos[1],
18-
selection.last_pos[2],
19-
}
8+
local treesitter = require("nvim-surround.treesitter")
209

21-
-- Get the root node of the current tree
22-
local lang_tree = ts_parsers.get_parser(0)
23-
local tree = lang_tree:trees()[1]
24-
local root = tree:root()
10+
local root = treesitter.get_root()
11+
if root == nil then
12+
return nil
13+
end
2514
-- DFS through the tree and find all nodes that have the given type
2615
local stack = { root }
2716
while #stack > 0 do
2817
local cur = stack[#stack]
2918
-- If the current node's range is equal to the desired selection, return the node
30-
if vim.deep_equal(range, { ts_utils.get_vim_range({ cur:range() }) }) then
19+
if vim.deep_equal(selection, treesitter.get_node_selection(cur)) then
3120
return cur
3221
end
3322
-- Pop off of the stack
@@ -40,48 +29,24 @@ M.get_node = function(selection)
4029
return nil
4130
end
4231

43-
-- Filters an existing parent selection down to a capture.
44-
---@param sexpr string The given S-expression containing the capture.
45-
---@param capture string The name of the capture to be returned.
46-
---@param parent_selection selection The parent selection to be filtered down.
47-
M.filter_selection = function(sexpr, capture, parent_selection)
48-
local parent_node = M.get_node(parent_selection)
49-
50-
local range = { ts_utils.get_vim_range({ parent_node:range() }) }
51-
local lang_tree = ts_parsers.get_parser(0)
52-
local ok, parsed_query = pcall(function()
53-
return vim.treesitter.query.parse and vim.treesitter.query.parse(lang_tree:lang(), sexpr)
54-
or vim.treesitter.parse_query(lang_tree:lang(), sexpr)
55-
end)
56-
if not ok or not parent_node then
57-
return {}
58-
end
59-
60-
for id, node in parsed_query:iter_captures(parent_node, 0, 0, -1) do
61-
local name = parsed_query.captures[id]
62-
if name == capture then
63-
range = { ts_utils.get_vim_range({ node:range() }) }
64-
return {
65-
first_pos = { range[1], range[2] },
66-
last_pos = { range[3], range[4] },
67-
}
68-
end
69-
end
70-
return nil
71-
end
72-
7332
-- Finds the nearest selection of a given query capture and its source.
7433
---@param capture string The capture to be retrieved.
7534
---@param type string The type of query to get the capture from.
7635
---@return selection|nil @The selection of the capture.
7736
---@nodiscard
7837
M.get_selection = function(capture, type)
38+
local utils = require("nvim-surround.utils")
39+
local treesitter = require("nvim-surround.treesitter")
40+
local ts_query = require("nvim-treesitter.query")
41+
7942
-- Get a table of all nodes that match the query
8043
local table_list = ts_query.get_capture_matches_recursively(0, capture, type)
8144
-- Convert the list of nodes into a list of selections
8245
local selections_list = {}
8346
for _, tab in ipairs(table_list) do
84-
local range = { ts_utils.get_vim_range({ tab.node:range() }) }
47+
local selection = treesitter.get_node_selection(tab.node)
48+
49+
local range = { selection.first_pos[1], selection.first_pos[2], selection.last_pos[1], selection.last_pos[2] }
8550
selections_list[#selections_list + 1] = {
8651
left = {
8752
first_pos = { range[1], range[2] },

lua/nvim-surround/treesitter.lua

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,63 @@ local function is_any_of(target, types)
1414
return false
1515
end
1616

17+
-- Gets the root node for the buffer.
18+
---@return TSNode|nil @The root node for the buffer.
19+
---@nodiscard
20+
M.get_root = function()
21+
if vim.treesitter then
22+
return vim.treesitter.get_node():tree():root()
23+
end
24+
25+
local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils")
26+
if not ok then
27+
return nil
28+
end
29+
30+
local node = ts_utils.get_node_at_cursor()
31+
if node == nil then
32+
return nil
33+
end
34+
return ts_utils.get_root_for_node(node)
35+
end
36+
37+
-- Gets the selection that a TreeSitter node spans.
38+
---@param node TSNode The given TreeSitter node.
39+
---@return selection @The span of the input node.
40+
---@nodiscard
41+
M.get_node_selection = function(node)
42+
---@type integer, integer, integer, integer
43+
local srow, scol, erow, ecol = node:range()
44+
srow = srow + 1
45+
scol = scol + 1
46+
erow = erow + 1
47+
48+
if ecol == 0 then
49+
-- Use the value of the last col of the previous row instead.
50+
erow = erow - 1
51+
ecol = math.max(1, vim.fn.col({ erow, "$" }) - 1)
52+
end
53+
return {
54+
first_pos = { srow, scol },
55+
last_pos = { erow, ecol },
56+
}
57+
end
58+
1759
-- Finds the nearest selection of a given Tree-sitter node type or types.
1860
---@param node_types string|string[] The Tree-sitter node type(s) to be retrieved.
1961
---@return selection|nil @The selection of the node.
2062
---@nodiscard
2163
M.get_selection = function(node_types)
64+
local utils = require("nvim-surround.utils")
65+
2266
if type(node_types) == "string" then
2367
node_types = { node_types }
2468
end
2569

26-
local utils = require("nvim-surround.utils")
27-
local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils")
28-
if not ok then
70+
local root = M.get_root()
71+
if root == nil then
2972
return nil
3073
end
31-
-- Find the root node of the given buffer
32-
local root = ts_utils.get_node_at_cursor()
33-
if not root then
34-
return {}
35-
end
36-
while root:parent() do
37-
root = root:parent()
38-
end
3974
-- DFS through the tree and find all nodes that have the given type
4075
local stack = { root }
4176
local nodes, selections_list = {}, {}
@@ -46,15 +81,8 @@ M.get_selection = function(node_types)
4681
-- Add the current node to the stack
4782
nodes[#nodes + 1] = cur
4883
-- Compute the node's selection and add it to the list
49-
local range = { ts_utils.get_vim_range({ cur:range() }) }
50-
selections_list[#selections_list + 1] = {
51-
left = {
52-
first_pos = { range[1], range[2] },
53-
},
54-
right = {
55-
last_pos = { range[3], range[4] },
56-
},
57-
}
84+
local selection = M.get_node_selection(cur)
85+
selections_list[#selections_list + 1] = selection
5886
end
5987
-- Pop off of the stack
6088
stack[#stack] = nil

0 commit comments

Comments
 (0)