1- local utils = require (" nvim-surround.utils" )
2- local ts_query = require (" nvim-treesitter.query" )
3- local ts_utils = require (" nvim-treesitter.ts_utils" )
4- local ts_parsers = require (" nvim-treesitter.parsers" )
5-
61local M = {}
72
83-- Retrieves the node that corresponds exactly to a given selection.
94--- @param selection selection The given selection.
10- --- @return _ @The corresponding node.
5+ --- @return TSNode | nil @The corresponding node.
116--- @nodiscard
127M .get_node = function (selection )
13- -- Convert the selection into a list
14- local range = {
15- selection .first_pos [1 ],
16- selection .first_pos [2 ],
17- selection .last_pos [1 ],
18- selection .last_pos [2 ],
19- }
8+ local treesitter = require (" nvim-surround.treesitter" )
209
21- -- Get the root node of the current tree
22- local lang_tree = ts_parsers . get_parser ( 0 )
23- local tree = lang_tree : trees ()[ 1 ]
24- local root = tree : root ()
10+ local root = treesitter . get_root ()
11+ if root == nil then
12+ return nil
13+ end
2514 -- DFS through the tree and find all nodes that have the given type
2615 local stack = { root }
2716 while # stack > 0 do
2817 local cur = stack [# stack ]
2918 -- If the current node's range is equal to the desired selection, return the node
30- if vim .deep_equal (range , { ts_utils . get_vim_range ({ cur : range () }) } ) then
19+ if vim .deep_equal (selection , treesitter . get_node_selection ( cur ) ) then
3120 return cur
3221 end
3322 -- Pop off of the stack
@@ -40,48 +29,24 @@ M.get_node = function(selection)
4029 return nil
4130end
4231
43- -- Filters an existing parent selection down to a capture.
44- --- @param sexpr string The given S-expression containing the capture.
45- --- @param capture string The name of the capture to be returned.
46- --- @param parent_selection selection The parent selection to be filtered down.
47- M .filter_selection = function (sexpr , capture , parent_selection )
48- local parent_node = M .get_node (parent_selection )
49-
50- local range = { ts_utils .get_vim_range ({ parent_node :range () }) }
51- local lang_tree = ts_parsers .get_parser (0 )
52- local ok , parsed_query = pcall (function ()
53- return vim .treesitter .query .parse and vim .treesitter .query .parse (lang_tree :lang (), sexpr )
54- or vim .treesitter .parse_query (lang_tree :lang (), sexpr )
55- end )
56- if not ok or not parent_node then
57- return {}
58- end
59-
60- for id , node in parsed_query :iter_captures (parent_node , 0 , 0 , - 1 ) do
61- local name = parsed_query .captures [id ]
62- if name == capture then
63- range = { ts_utils .get_vim_range ({ node :range () }) }
64- return {
65- first_pos = { range [1 ], range [2 ] },
66- last_pos = { range [3 ], range [4 ] },
67- }
68- end
69- end
70- return nil
71- end
72-
7332-- Finds the nearest selection of a given query capture and its source.
7433--- @param capture string The capture to be retrieved.
7534--- @param type string The type of query to get the capture from.
7635--- @return selection | nil @The selection of the capture.
7736--- @nodiscard
7837M .get_selection = function (capture , type )
38+ local utils = require (" nvim-surround.utils" )
39+ local treesitter = require (" nvim-surround.treesitter" )
40+ local ts_query = require (" nvim-treesitter.query" )
41+
7942 -- Get a table of all nodes that match the query
8043 local table_list = ts_query .get_capture_matches_recursively (0 , capture , type )
8144 -- Convert the list of nodes into a list of selections
8245 local selections_list = {}
8346 for _ , tab in ipairs (table_list ) do
84- local range = { ts_utils .get_vim_range ({ tab .node :range () }) }
47+ local selection = treesitter .get_node_selection (tab .node )
48+
49+ local range = { selection .first_pos [1 ], selection .first_pos [2 ], selection .last_pos [1 ], selection .last_pos [2 ] }
8550 selections_list [# selections_list + 1 ] = {
8651 left = {
8752 first_pos = { range [1 ], range [2 ] },
0 commit comments