1- local utils = require (" nvim-surround.utils" )
2- local ts_query = require (" nvim-treesitter.query" )
3- local ts_utils = require (" nvim-treesitter.ts_utils" )
4- local ts_parsers = require (" nvim-treesitter.parsers" )
5-
61local M = {}
72
3+ -- Some compatibility shims over the builtin `vim.treesitter` functions
4+ local get_query = vim .treesitter .get_query or vim .treesitter .query .get
5+
86-- Retrieves the node that corresponds exactly to a given selection.
97--- @param selection selection The given selection.
10- --- @return _ @The corresponding node.
8+ --- @return TSNode | nil @The corresponding node.
119--- @nodiscard
1210M .get_node = function (selection )
13- -- Convert the selection into a list
14- local range = {
15- selection .first_pos [1 ],
16- selection .first_pos [2 ],
17- selection .last_pos [1 ],
18- selection .last_pos [2 ],
19- }
11+ local treesitter = require (" nvim-surround.treesitter" )
2012
21- -- Get the root node of the current tree
22- local lang_tree = ts_parsers . get_parser ( 0 )
23- local tree = lang_tree : trees ()[ 1 ]
24- local root = tree : root ()
13+ local root = treesitter . get_root ()
14+ if root == nil then
15+ return nil
16+ end
2517 -- DFS through the tree and find all nodes that have the given type
2618 local stack = { root }
2719 while # stack > 0 do
2820 local cur = stack [# stack ]
2921 -- If the current node's range is equal to the desired selection, return the node
30- if vim .deep_equal (range , { ts_utils . get_vim_range ({ cur : range () }) } ) then
22+ if vim .deep_equal (selection , treesitter . get_node_selection ( cur ) ) then
3123 return cur
3224 end
3325 -- Pop off of the stack
@@ -40,59 +32,48 @@ M.get_node = function(selection)
4032 return nil
4133end
4234
43- -- Filters an existing parent selection down to a capture.
44- --- @param sexpr string The given S-expression containing the capture.
45- --- @param capture string The name of the capture to be returned.
46- --- @param parent_selection selection The parent selection to be filtered down.
47- M .filter_selection = function (sexpr , capture , parent_selection )
48- local parent_node = M .get_node (parent_selection )
49-
50- local range = { ts_utils .get_vim_range ({ parent_node :range () }) }
51- local lang_tree = ts_parsers .get_parser (0 )
52- local ok , parsed_query = pcall (function ()
53- return vim .treesitter .query .parse and vim .treesitter .query .parse (lang_tree :lang (), sexpr )
54- or vim .treesitter .parse_query (lang_tree :lang (), sexpr )
55- end )
56- if not ok or not parent_node then
57- return {}
58- end
59-
60- for id , node in parsed_query :iter_captures (parent_node , 0 , 0 , - 1 ) do
61- local name = parsed_query .captures [id ]
62- if name == capture then
63- range = { ts_utils .get_vim_range ({ node :range () }) }
64- return {
65- first_pos = { range [1 ], range [2 ] },
66- last_pos = { range [3 ], range [4 ] },
67- }
68- end
69- end
70- return nil
71- end
72-
7335-- Finds the nearest selection of a given query capture and its source.
7436--- @param capture string The capture to be retrieved.
7537--- @param type string The type of query to get the capture from.
7638--- @return selection | nil @The selection of the capture.
7739--- @nodiscard
7840M .get_selection = function (capture , type )
79- -- Get a table of all nodes that match the query
80- local table_list = ts_query .get_capture_matches_recursively (0 , capture , type )
81- -- Convert the list of nodes into a list of selections
41+ local utils = require (" nvim-surround.utils" )
42+ local treesitter = require (" nvim-surround.treesitter" )
43+
44+ local root = treesitter .get_root ()
45+ local query = get_query (vim .bo .filetype , type )
46+ if root == nil or query == nil then
47+ return nil
48+ end
49+
50+ -- Get a list of all selections in the query that match the capture group
8251 local selections_list = {}
83- for _ , tab in ipairs (table_list ) do
84- local range = { ts_utils .get_vim_range ({ tab .node :range () }) }
85- selections_list [# selections_list + 1 ] = {
86- left = {
87- first_pos = { range [1 ], range [2 ] },
88- last_pos = { range [3 ], range [4 ] },
89- },
90- right = {
91- first_pos = { range [3 ], range [4 ] + 1 },
92- last_pos = { range [3 ], range [4 ] },
93- },
94- }
52+ for id , node in query :iter_captures (root , 0 ) do
53+ local name = query .captures [id ]
54+ -- TODO: Figure out why sometimes the name from a capture group like `@call.outer` is missing the `@`
55+ if capture :sub (1 , 1 ) == " @" then
56+ capture = capture :sub (1 - capture :len ())
57+ end
58+
59+ if name == capture then
60+ local selection = treesitter .get_node_selection (node )
61+
62+ local range =
63+ { selection .first_pos [1 ], selection .first_pos [2 ], selection .last_pos [1 ], selection .last_pos [2 ] }
64+ selections_list [# selections_list + 1 ] = {
65+ left = {
66+ first_pos = { range [1 ], range [2 ] },
67+ last_pos = { range [3 ], range [4 ] },
68+ },
69+ right = {
70+ first_pos = { range [3 ], range [4 ] + 1 },
71+ last_pos = { range [3 ], range [4 ] },
72+ },
73+ }
74+ end
9575 end
76+
9677 -- Filter out the best pair of selections from the list
9778 local best_selections = utils .filter_selections_list (selections_list )
9879 return best_selections
0 commit comments