Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ Surround selections, stylishly :sunglasses:
## :lock: Requirements

- [Neovim 0.8+](https://github.com/neovim/neovim/releases)
- \[Recommended] If
[nvim-treesitter](https://github.com/nvim-treesitter/nvim-treesitter) is
installed, then Tree-sitter nodes may be surrounded and modified, in addition
to just Vim motions and Lua patterns
- \[Recommended] If
[nvim-treesitter-textobjects](https://github.com/nvim-treesitter/nvim-treesitter-textobjects)
is installed, then Tree-sitter text-objects can be used to define surrounds,
Expand Down
20 changes: 10 additions & 10 deletions lua/nvim-surround/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,16 @@ M.default_opts = {
end
end,
find = function()
if vim.g.loaded_nvim_treesitter then
local selection = M.get_selection({
query = {
capture = "@call.outer",
type = "textobjects",
},
})
if selection then
return selection
end
local selection = M.get_selection({
query = {
capture = "@call.outer",
type = "textobjects",
},
})

-- We prioritize TreeSitter-based selections if they exist, otherwise fallback on pattern-based search
if selection then
return selection
end
return M.get_selection({ pattern = "[^=%s%(%){}]+%b()" })
end,
Expand Down
107 changes: 44 additions & 63 deletions lua/nvim-surround/queries.lua
Original file line number Diff line number Diff line change
@@ -1,33 +1,25 @@
local utils = require("nvim-surround.utils")
local ts_query = require("nvim-treesitter.query")
local ts_utils = require("nvim-treesitter.ts_utils")
local ts_parsers = require("nvim-treesitter.parsers")

local M = {}

-- Some compatibility shims over the builtin `vim.treesitter` functions
local get_query = vim.treesitter.get_query or vim.treesitter.query.get

-- Retrieves the node that corresponds exactly to a given selection.
---@param selection selection The given selection.
---@return _ @The corresponding node.
---@return TSNode|nil @The corresponding node.
---@nodiscard
M.get_node = function(selection)
-- Convert the selection into a list
local range = {
selection.first_pos[1],
selection.first_pos[2],
selection.last_pos[1],
selection.last_pos[2],
}
local treesitter = require("nvim-surround.treesitter")

-- Get the root node of the current tree
local lang_tree = ts_parsers.get_parser(0)
local tree = lang_tree:trees()[1]
local root = tree:root()
local root = treesitter.get_root()
if root == nil then
return nil
end
-- DFS through the tree and find all nodes that have the given type
local stack = { root }
while #stack > 0 do
local cur = stack[#stack]
-- If the current node's range is equal to the desired selection, return the node
if vim.deep_equal(range, { ts_utils.get_vim_range({ cur:range() }) }) then
if vim.deep_equal(selection, treesitter.get_node_selection(cur)) then
return cur
end
-- Pop off of the stack
Expand All @@ -40,59 +32,48 @@ M.get_node = function(selection)
return nil
end

-- Filters an existing parent selection down to a capture.
---@param sexpr string The given S-expression containing the capture.
---@param capture string The name of the capture to be returned.
---@param parent_selection selection The parent selection to be filtered down.
M.filter_selection = function(sexpr, capture, parent_selection)
local parent_node = M.get_node(parent_selection)

local range = { ts_utils.get_vim_range({ parent_node:range() }) }
local lang_tree = ts_parsers.get_parser(0)
local ok, parsed_query = pcall(function()
return vim.treesitter.query.parse and vim.treesitter.query.parse(lang_tree:lang(), sexpr)
or vim.treesitter.parse_query(lang_tree:lang(), sexpr)
end)
if not ok or not parent_node then
return {}
end

for id, node in parsed_query:iter_captures(parent_node, 0, 0, -1) do
local name = parsed_query.captures[id]
if name == capture then
range = { ts_utils.get_vim_range({ node:range() }) }
return {
first_pos = { range[1], range[2] },
last_pos = { range[3], range[4] },
}
end
end
return nil
end

-- Finds the nearest selection of a given query capture and its source.
---@param capture string The capture to be retrieved.
---@param type string The type of query to get the capture from.
---@return selection|nil @The selection of the capture.
---@nodiscard
M.get_selection = function(capture, type)
-- Get a table of all nodes that match the query
local table_list = ts_query.get_capture_matches_recursively(0, capture, type)
-- Convert the list of nodes into a list of selections
local utils = require("nvim-surround.utils")
local treesitter = require("nvim-surround.treesitter")

local root = treesitter.get_root()
local query = get_query(vim.bo.filetype, type)
if root == nil or query == nil then
return nil
end

-- Get a list of all selections in the query that match the capture group
local selections_list = {}
for _, tab in ipairs(table_list) do
local range = { ts_utils.get_vim_range({ tab.node:range() }) }
selections_list[#selections_list + 1] = {
left = {
first_pos = { range[1], range[2] },
last_pos = { range[3], range[4] },
},
right = {
first_pos = { range[3], range[4] + 1 },
last_pos = { range[3], range[4] },
},
}
for id, node in query:iter_captures(root, 0) do
local name = query.captures[id]
-- TODO: Figure out why sometimes the name from a capture group like `@call.outer` is missing the `@`
if capture:sub(1, 1) == "@" then
capture = capture:sub(1 - capture:len())
end

if name == capture then
local selection = treesitter.get_node_selection(node)

local range =
{ selection.first_pos[1], selection.first_pos[2], selection.last_pos[1], selection.last_pos[2] }
selections_list[#selections_list + 1] = {
left = {
first_pos = { range[1], range[2] },
last_pos = { range[3], range[4] },
},
right = {
first_pos = { range[3], range[4] + 1 },
last_pos = { range[3], range[4] },
},
}
end
end

-- Filter out the best pair of selections from the list
local best_selections = utils.filter_selections_list(selections_list)
return best_selections
Expand Down
89 changes: 69 additions & 20 deletions lua/nvim-surround/treesitter.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
local M = {}

-- Gets a TreeSitter node from the buffer. A thin compatibility shim over `vim.treesitter.get_node`.
---@param opts vim.treesitter.get_node.Opts
---@return TSNode|nil @The smallest named node at the given position.
---@nodiscard
local get_node = function(opts)
if vim.treesitter.get_node_at_pos then
local buffer = require("nvim-surround.buffer")

opts = opts or {}
-- The input position is (0, 0) indexed
local pos = opts.pos or { buffer.get_curpos()[1] - 1, buffer.get_curpos()[2] - 1 }
local bufnr = opts.bufnr or 0
return vim.treesitter.get_node_at_pos(bufnr, pos[1], pos[2], opts)
end

return vim.treesitter.get_node(opts)
end

-- Returns whether or not a target node type is found in a list of types.
---@param target string The target type to be found.
---@param types string[] The list of types to search through.
Expand All @@ -14,28 +32,66 @@ local function is_any_of(target, types)
return false
end

-- Gets the root node for the buffer.
---@return TSNode|nil @The root node for the buffer.
---@nodiscard
M.get_root = function()
local node = get_node()
if node == nil then
return nil
end

while node:parent() ~= nil do
node = node:parent()
end

return node
end

-- Gets the current smallest node at the cursor.
---@return TSNode|nil @The smallest node containing the cursor, in the current buffer.
---@nodiscard
M.get_node_at_cursor = function()
return get_node({ ignore_injections = false })
end

-- Gets the selection that a TreeSitter node spans.
---@param node TSNode The given TreeSitter node.
---@return selection @The span of the input node.
---@nodiscard
M.get_node_selection = function(node)
---@type integer, integer, integer, integer
local srow, scol, erow, ecol = node:range()
srow = srow + 1
scol = scol + 1
erow = erow + 1

if ecol == 0 then
-- Use the value of the last col of the previous row instead.
erow = erow - 1
ecol = math.max(1, vim.fn.col({ erow, "$" }) - 1)
end
return {
first_pos = { srow, scol },
last_pos = { erow, ecol },
}
end

-- Finds the nearest selection of a given Tree-sitter node type or types.
---@param node_types string|string[] The Tree-sitter node type(s) to be retrieved.
---@return selection|nil @The selection of the node.
---@nodiscard
M.get_selection = function(node_types)
local utils = require("nvim-surround.utils")

if type(node_types) == "string" then
node_types = { node_types }
end

local utils = require("nvim-surround.utils")
local ok, ts_utils = pcall(require, "nvim-treesitter.ts_utils")
if not ok then
local root = M.get_root()
if root == nil then
return nil
end
-- Find the root node of the given buffer
local root = ts_utils.get_node_at_cursor()
if not root then
return {}
end
while root:parent() do
root = root:parent()
end
-- DFS through the tree and find all nodes that have the given type
local stack = { root }
local nodes, selections_list = {}, {}
Expand All @@ -46,15 +102,8 @@ M.get_selection = function(node_types)
-- Add the current node to the stack
nodes[#nodes + 1] = cur
-- Compute the node's selection and add it to the list
local range = { ts_utils.get_vim_range({ cur:range() }) }
selections_list[#selections_list + 1] = {
left = {
first_pos = { range[1], range[2] },
},
right = {
last_pos = { range[3], range[4] },
},
}
local selection = M.get_node_selection(cur)
selections_list[#selections_list + 1] = selection
end
-- Pop off of the stack
stack[#stack] = nil
Expand Down
Loading