Skip to content
Open
74 changes: 31 additions & 43 deletions lua/cmp/config/compare.lua
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ compare.locality = setmetatable({
---scopes: Entries defined in a closer scope will be ranked higher (e.g., prefer local variables to globals).
---@type cmp.ComparatorFunctor
compare.scopes = setmetatable({
scopes_map = {},
definition_depths = {},
has_nvim_0_9_features = vim.fn.has('nvim-0.9') == 1,
update = function(self)
local config = require('cmp').get_config()
if not vim.tbl_contains(config.sorting.comparators, compare.scopes) then
Expand All @@ -207,64 +208,51 @@ compare.scopes = setmetatable({

local ok, locals = pcall(require, 'nvim-treesitter.locals')
if ok then
local win, buf = vim.api.nvim_get_current_win(), vim.api.nvim_get_current_buf()
local cursor_row = vim.api.nvim_win_get_cursor(win)[1] - 1

-- Cursor scope.
local cursor_scope = nil
-- Prioritize the older get_scopes method from nvim-treesitter `master` over get from `main`
local scopes = locals.get_scopes and locals.get_scopes(buf) or select(3, locals.get(buf))
for _, scope in ipairs(scopes) do
if scope:start() <= cursor_row and cursor_row <= scope:end_() then
if not cursor_scope then
cursor_scope = scope
else
if cursor_scope:start() <= scope:start() and scope:end_() <= cursor_scope:end_() then
cursor_scope = scope
end
end
elseif cursor_scope and cursor_scope:end_() <= scope:start() then
break
end
self.definition_depths = {}
local buf = vim.api.nvim_get_current_buf()
if self.has_nvim_0_9_features and not vim.b[buf].cmp_buf_has_ts_parser then
return
end

-- Definitions.
local definitions = locals.get_definitions_lookup_table(buf)

-- Narrow definitions.
local get_cursor_node = vim.treesitter.get_node or require('nvim-treesitter.ts_utils').get_node_at_cursor
local cursor_node = get_cursor_node()
local scope_depths = {}
local depth = 0
for scope in locals.iter_scope_tree(cursor_scope, buf) do
local s, e = scope:start(), scope:end_()
-- If there's no cursor node, no iterations are made.
---@diagnostic disable-next-line: param-type-mismatch
for scope in locals.iter_scope_tree(cursor_node, buf) do
scope_depths[scope:id()] = depth
depth = depth + 1
end

-- Check scope's direct child.
for _, definition in pairs(definitions) do
if s <= definition.node:start() and definition.node:end_() <= e then
if scope:id() == locals.containing_scope(definition.node, buf):id() then
local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text
local text = get_node_text(definition.node, buf) or ''
if not self.scopes_map[text] then
self.scopes_map[text] = depth
end
end
-- Map definitions based on their scope relative to the cursor.
local definitions = locals.get_definitions_lookup_table(buf)
local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text
for _, definition in pairs(definitions) do
local definition_depth = scope_depths[locals.containing_scope(definition.node, buf):id()]
local def_text = get_node_text(definition.node, buf) or ''
if definition_depth then
-- Prefer the closest scoped definitions.
if not self.definition_depths[def_text] or self.definition_depths[def_text] > definition_depth then
self.definition_depths[def_text] = definition_depth
end
end
depth = depth + 1
end
end
end,
}, {
---@type fun(self: table, entry1: cmp.Entry, entry2: cmp.Entry): boolean|nil
__call = function(self, entry1, entry2)
local local1 = self.scopes_map[entry1.word]
local local2 = self.scopes_map[entry2.word]
if local1 ~= local2 then
if local1 == nil then
local def_depth1 = self.definition_depths[entry1.word]
local def_depth2 = self.definition_depths[entry2.word]
if def_depth1 ~= def_depth2 then
if def_depth1 == nil then
return false
end
if local2 == nil then
if def_depth2 == nil then
return true
end
return local1 < local2
return def_depth1 < def_depth2
end
end,
})
Expand Down
9 changes: 5 additions & 4 deletions lua/cmp/utils/autocmd.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ local function create_autocmd(event)
vim.api.nvim_create_autocmd(event, {
desc = ('nvim-cmp: autocmd: %s'):format(event),
group = autocmd.group,
callback = function()
autocmd.emit(event)
callback = function(details)
autocmd.emit(event, details)
end,
})
end
Expand Down Expand Up @@ -45,12 +45,13 @@ end

---Emit autocmd
---@param event string
autocmd.emit = function(event)
---@param details table|nil
autocmd.emit = function(event, details)
debug.log(' ')
debug.log(string.format('>>> %s', event))
autocmd.events[event] = autocmd.events[event] or {}
for _, callback in ipairs(autocmd.events[event]) do
callback()
callback(details)
end
end

Expand Down
24 changes: 24 additions & 0 deletions plugin/cmp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ if vim.on_key then
end, vim.api.nvim_create_namespace('cmp.plugin'))
end

-- see compare.scopes
if vim.fn.has('nvim-0.9') == 1 then
local ts = vim.treesitter
local has_ts_parser = ts.language.get_lang
-- vim.treesitter.language.add is recommended for checking treesitter in 0.11 nightly
if vim.fn.has('nvim-0.11') then
has_ts_parser = function(filetype)
local lang = ts.language.get_lang(filetype)
return lang and ts.language.add(lang)
end
end
autocmd.subscribe({ 'FileType' }, function(details)
if has_ts_parser(details.match) then
vim.b[details.buf].cmp_buf_has_ts_parser = true
else
vim.b[details.buf].cmp_buf_has_ts_parser = false
end
end)
autocmd.subscribe({ 'BufUnload' }, function(details)
if vim.treesitter.language.get_lang(details.match) then
vim.b[details.buf].cmp_buf_has_ts_parser = false
end
end)
end

vim.api.nvim_create_user_command('CmpStatus', function()
require('cmp').status()
Expand Down
Loading