Skip to content

Commit a925c03

Browse files
authored
fix: Use native vim.treesitter module when available. (#391)
* fix: Use native `vim.treesitter` module when available. * docs: Improve docstring, remove `nvim-treesitter` from README.
1 parent 8dd9150 commit a925c03

File tree

4 files changed

+123
-97
lines changed

4 files changed

+123
-97
lines changed

README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@ Surround selections, stylishly :sunglasses:
2424
## :lock: Requirements
2525

2626
- [Neovim 0.8+](https://github.com/neovim/neovim/releases)
27-
- \[Recommended] If
28-
[nvim-treesitter](https://github.com/nvim-treesitter/nvim-treesitter) is
29-
installed, then Tree-sitter nodes may be surrounded and modified, in addition
30-
to just Vim motions and Lua patterns
3127
- \[Recommended] If
3228
[nvim-treesitter-textobjects](https://github.com/nvim-treesitter/nvim-treesitter-textobjects)
3329
is installed, then Tree-sitter text-objects can be used to define surrounds,

lua/nvim-surround/config.lua

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ M.default_opts = {
178178
end
179179
end,
180180
find = function()
181-
if vim.g.loaded_nvim_treesitter then
182-
local selection = M.get_selection({
183-
query = {
184-
capture = "@call.outer",
185-
type = "textobjects",
186-
},
187-
})
188-
if selection then
189-
return selection
190-
end
181+
local selection = M.get_selection({
182+
query = {
183+
capture = "@call.outer",
184+
type = "textobjects",
185+
},
186+
})
187+
188+
-- We prioritize TreeSitter-based selections if they exist, otherwise fallback on pattern-based search
189+
if selection then
190+
return selection
191191
end
192192
return M.get_selection({ pattern = "[^=%s%(%){}]+%b()" })
193193
end,

lua/nvim-surround/queries.lua

Lines changed: 44 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,25 @@
1-
local utils = require("nvim-surround.utils")
2-
local ts_query = require("nvim-treesitter.query")
3-
local ts_utils = require("nvim-treesitter.ts_utils")
4-
local ts_parsers = require("nvim-treesitter.parsers")
5-
61
local M = {}
72

3+
-- Some compatibility shims over the builtin `vim.treesitter` functions
4+
local get_query = vim.treesitter.get_query or vim.treesitter.query.get
5+
86
-- Retrieves the node that corresponds exactly to a given selection.
97
---@param selection selection The given selection.
10-
---@return _ @The corresponding node.
8+
---@return TSNode|nil @The corresponding node.
119
---@nodiscard
1210
M.get_node = function(selection)
13-
-- Convert the selection into a list
14-
local range = {
15-
selection.first_pos[1],
16-
selection.first_pos[2],
17-
selection.last_pos[1],
18-
selection.last_pos[2],
19-
}
11+
local treesitter = require("nvim-surround.treesitter")
2012

21-
-- Get the root node of the current tree
22-
local lang_tree = ts_parsers.get_parser(0)
23-
local tree = lang_tree:trees()[1]
24-
local root = tree:root()
13+
local root = treesitter.get_root()
14+
if root == nil then
15+
return nil
16+
end
2517
-- DFS through the tree and find all nodes that have the given type
2618
local stack = { root }
2719
while #stack > 0 do
2820
local cur = stack[#stack]
2921
-- If the current node's range is equal to the desired selection, return the node
30-
if vim.deep_equal(range, { ts_utils.get_vim_range({ cur:range() }) }) then
22+
if vim.deep_equal(selection, treesitter.get_node_selection(cur)) then
3123
return cur
3224
end
3325
-- Pop off of the stack
@@ -40,59 +32,48 @@ M.get_node = function(selection)
4032
return nil
4133
end
4234

43-
-- Filters an existing parent selection down to a capture.
44-
---@param sexpr string The given S-expression containing the capture.
45-
---@param capture string The name of the capture to be returned.
46-
---@param parent_selection selection The parent selection to be filtered down.
47-
M.filter_selection = function(sexpr, capture, parent_selection)
48-
local parent_node = M.get_node(parent_selection)
49-
50-
local range = { ts_utils.get_vim_range({ parent_node:range() }) }
51-
local lang_tree = ts_parsers.get_parser(0)
52-
local ok, parsed_query = pcall(function()
53-
return vim.treesitter.query.parse and vim.treesitter.query.parse(lang_tree:lang(), sexpr)
54-
or vim.treesitter.parse_query(lang_tree:lang(), sexpr)
55-
end)
56-
if not ok or not parent_node then
57-
return {}
58-
end
59-
60-
for id, node in parsed_query:iter_captures(parent_node, 0, 0, -1) do
61-
local name = parsed_query.captures[id]
62-
if name == capture then
63-
range = { ts_utils.get_vim_range({ node:range() }) }
64-
return {
65-
first_pos = { range[1], range[2] },
66-
last_pos = { range[3], range[4] },
67-
}
68-
end
69-
end
70-
return nil
71-
end
72-
7335
-- Finds the nearest selection of a given query capture and its source.
7436
---@param capture string The capture to be retrieved.
7537
---@param type string The type of query to get the capture from.
7638
---@return selection|nil @The selection of the capture.
7739
---@nodiscard
7840
M.get_selection = function(capture, type)
79-
-- Get a table of all nodes that match the query
80-
local table_list = ts_query.get_capture_matches_recursively(0, capture, type)
81-
-- Convert the list of nodes into a list of selections
41+
local utils = require("nvim-surround.utils")
42+
local treesitter = require("nvim-surround.treesitter")
43+
44+
local root = treesitter.get_root()
45+
local query = get_query(vim.bo.filetype, type)
46+
if root == nil or query == nil then
47+
return nil
48+
end
49+
50+
-- Get a list of all selections in the query that match the capture group
8251
local selections_list = {}
83-
for _, tab in ipairs(table_list) do
84-
local range = { ts_utils.get_vim_range({ tab.node:range() }) }
85-
selections_list[#selections_list + 1] = {
86-
left = {
87-
first_pos = { range[1], range[2] },
88-
last_pos = { range[3], range[4] },
89-
},
90-
right = {
91-
first_pos = { range[3], range[4] + 1 },
92-
last_pos = { range[3], range[4] },
93-
},
94-
}
52+
for id, node in query:iter_captures(root, 0) do
53+
local name = query.captures[id]
54+
-- TODO: Figure out why sometimes the name from a capture group like `@call.outer` is missing the `@`
55+
if capture:sub(1, 1) == "@" then
56+
capture = capture:sub(1 - capture:len())
57+
end
58+
59+
if name == capture then
60+
local selection = treesitter.get_node_selection(node)
61+
62+
local range =
63+
{ selection.first_pos[1], selection.first_pos[2], selection.last_pos[1], selection.last_pos[2] }
64+
selections_list[#selections_list + 1] = {
65+
left = {
66+
first_pos = { range[1], range[2] },
67+
last_pos = { range[3], range[4] },
68+
},
69+
right = {
70+
first_pos = { range[3], range[4] + 1 },
71+
last_pos = { range[3], range[4] },
72+
},
73+
}
74+
end
9575
end
76+
9677
-- Filter out the best pair of selections from the list
9778
local best_selections = utils.filter_selections_list(selections_list)
9879
return best_selections

lua/nvim-surround/treesitter.lua

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
local M = {}
22

3+
-- Gets a TreeSitter node from the buffer. A thin compatibility shim over `vim.treesitter.get_node`.
4+
---@param opts vim.treesitter.get_node.Opts
5+
---@return TSNode|nil @The smallest named node at the given position.
6+
---@nodiscard
7+
local get_node = function(opts)
8+
if vim.treesitter.get_node_at_pos then
9+
local buffer = require("nvim-surround.buffer")
10+
11+
opts = opts or {}
12+
-- The input position is (0, 0) indexed
13+
local pos = opts.pos or { buffer.get_curpos()[1] - 1, buffer.get_curpos()[2] - 1 }
14+
local bufnr = opts.bufnr or 0
15+
return vim.treesitter.get_node_at_pos(bufnr, pos[1], pos[2], opts)
16+
end
17+
18+
return vim.treesitter.get_node(opts)
19+
end
20+
321
-- Returns whether or not a target node type is found in a list of types.
422
---@param target string The target type to be found.
523
---@param types string[] The list of types to search through.
@@ -14,28 +32,66 @@ local function is_any_of(target, types)
1432
return false
1533
end
1634

35+
-- Gets the root node for the buffer.
36+
---@return TSNode|nil @The root node for the buffer.
37+
---@nodiscard
38+
M.get_root = function()
39+
local node = get_node()
40+
if node == nil then
41+
return nil
42+
end
43+
44+
while node:parent() ~= nil do
45+
node = node:parent()
46+
end
47+
48+
return node
49+
end
50+
51+
-- Gets the current smallest node at the cursor.
52+
---@return TSNode|nil @The smallest node containing the cursor, in the current buffer.
53+
---@nodiscard
54+
M.get_node_at_cursor = function()
55+
return get_node({ ignore_injections = false })
56+
end
57+
58+
-- Gets the selection that a TreeSitter node spans.
59+
---@param node TSNode The given TreeSitter node.
60+
---@return selection @The span of the input node.
61+
---@nodiscard
62+
M.get_node_selection = function(node)
63+
---@type integer, integer, integer, integer
64+
local srow, scol, erow, ecol = node:range()
65+
srow = srow + 1
66+
scol = scol + 1
67+
erow = erow + 1
68+
69+
if ecol == 0 then
70+
-- Use the value of the last col of the previous row instead.
71+
erow = erow - 1
72+
ecol = math.max(1, vim.fn.col({ erow, "$" }) - 1)
73+
end
74+
return {
75+
first_pos = { srow, scol },
76+
last_pos = { erow, ecol },
77+
}
78+
end
79+
1780
-- Finds the nearest selection of a given Tree-sitter node type or types.
1881
---@param node_types string|string[] The Tree-sitter node type(s) to be retrieved.
1982
---@return selection|nil @The selection of the node.
2083
---@nodiscard
2184
M.get_selection = function(node_types)
85+
local utils = require("nvim-surround.utils")
86+
2287
if type(node_types) == "string" then
2388
node_types = { node_types }
2489
end
2590

26-
local utils = require("nvim-surround.utils")
27-
local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils")
28-
if not ok then
91+
local root = M.get_root()
92+
if root == nil then
2993
return nil
3094
end
31-
-- Find the root node of the given buffer
32-
local root = ts_utils.get_node_at_cursor()
33-
if not root then
34-
return {}
35-
end
36-
while root:parent() do
37-
root = root:parent()
38-
end
3995
-- DFS through the tree and find all nodes that have the given type
4096
local stack = { root }
4197
local nodes, selections_list = {}, {}
@@ -46,15 +102,8 @@ M.get_selection = function(node_types)
46102
-- Add the current node to the stack
47103
nodes[#nodes + 1] = cur
48104
-- Compute the node's selection and add it to the list
49-
local range = { ts_utils.get_vim_range({ cur:range() }) }
50-
selections_list[#selections_list + 1] = {
51-
left = {
52-
first_pos = { range[1], range[2] },
53-
},
54-
right = {
55-
last_pos = { range[3], range[4] },
56-
},
57-
}
105+
local selection = M.get_node_selection(cur)
106+
selections_list[#selections_list + 1] = selection
58107
end
59108
-- Pop off of the stack
60109
stack[#stack] = nil

0 commit comments

Comments
 (0)