aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/common/config.odin1
-rw-r--r--src/server/action.odin5
-rw-r--r--src/server/action_invert_if_statements.odin2
-rw-r--r--src/server/analysis.odin349
-rw-r--r--src/server/collector.odin16
-rw-r--r--src/server/generics.odin21
-rw-r--r--src/server/references.odin52
-rw-r--r--src/server/requests.odin158
-rw-r--r--src/server/types.odin6
-rw-r--r--src/testing/testing.odin3
10 files changed, 358 insertions, 255 deletions
diff --git a/src/common/config.odin b/src/common/config.odin
index f2ae68f..d1bfd31 100644
--- a/src/common/config.odin
+++ b/src/common/config.odin
@@ -41,6 +41,7 @@ Config :: struct {
enable_document_links: bool,
enable_comp_lit_signature_help: bool,
enable_comp_lit_signature_help_use_docs: bool,
+ enable_code_action_invert_if: bool,
disable_parser_errors: bool,
thread_count: int,
file_log: bool,
diff --git a/src/server/action.odin b/src/server/action.odin
index 11e8f86..ce0dbb3 100644
--- a/src/server/action.odin
+++ b/src/server/action.odin
@@ -79,7 +79,10 @@ get_code_actions :: proc(document: ^Document, range: common.Range, config: ^comm
&actions,
)
}
- add_invert_if_action(document, position_context.position, strings.clone(document.uri.uri), &actions)
+
+ if config.enable_code_action_invert_if {
+ add_invert_if_action(document, position_context.position, strings.clone(document.uri.uri), &actions)
+ }
return actions[:], true
}
diff --git a/src/server/action_invert_if_statements.odin b/src/server/action_invert_if_statements.odin
index 047b3a2..336caaa 100644
--- a/src/server/action_invert_if_statements.odin
+++ b/src/server/action_invert_if_statements.odin
@@ -3,10 +3,8 @@
package server
import "core:fmt"
-import "core:log"
import "core:odin/ast"
import "core:odin/tokenizer"
-import path "core:path/slashpath"
import "core:strings"
import "src:common"
diff --git a/src/server/analysis.odin b/src/server/analysis.odin
index 1a85708..e980913 100644
--- a/src/server/analysis.odin
+++ b/src/server/analysis.odin
@@ -697,6 +697,108 @@ should_resolve_all_proc_overload_possibilities :: proc(ast_context: ^AstContext,
return ast_context.position_hint == .Completion || ast_context.position_hint == .SignatureHelp || call_expr == nil
}
+CallArg :: struct {
+ symbol: Symbol,
+ implicit_selector: ^ast.Implicit_Selector_Expr,
+ name: string,
+ named: bool,
+ is_nil: bool,
+ bad_expr: bool,
+ has_symbol: bool,
+ is_poly_type: bool,
+}
+
+expand_call_args :: proc(ast_context: ^AstContext, call: ^ast.Call_Expr) -> ([]CallArg, bool) {
+ results := make([dynamic]CallArg, context.temp_allocator)
+ if call == nil {
+ return results[:], true
+ }
+
+ used_named := false
+ append_arg :: proc(
+ ast_context: ^AstContext,
+ arg: ^ast.Expr,
+ results: ^[dynamic]CallArg,
+ used_named: ^bool,
+ ) -> bool {
+ ast_context.use_locals = true
+
+ call_arg := CallArg{}
+
+ if _, ok := arg.derived.(^ast.Bad_Expr); ok {
+ call_arg.bad_expr = true
+ append(results, call_arg)
+ return true
+ }
+
+ value_expr := arg
+
+ //named parameter
+ if field, ok := arg.derived.(^ast.Field_Value); ok {
+ call_arg.named = true
+ value_expr = field.value
+ used_named^ = true
+
+ if ident, ok := field.field.derived.(^ast.Ident); ok {
+ call_arg.name = ident.name
+ }
+ } else if used_named^ {
+ log.error("Expected name parameter after starting named parmeter phase")
+ return false
+ }
+
+ if ident, ok := value_expr.derived.(^ast.Ident); ok && ident.name == "nil" {
+ call_arg.is_nil = true
+ append(results, call_arg)
+ return true
+ } else if implicit, ok := value_expr.derived.(^ast.Implicit_Selector_Expr); ok {
+ call_arg.implicit_selector = implicit
+ append(results, call_arg)
+ return true
+ }
+
+ if symbol, ok := resolve_call_arg_type_expression(ast_context, value_expr); ok {
+ call_arg.symbol = symbol
+ call_arg.has_symbol = true
+ if _, ok := symbol.value.(SymbolPolyTypeValue); ok {
+ call_arg.is_poly_type = true
+ append(results, call_arg)
+ return true
+ } else if v, ok := symbol.value.(SymbolProcedureValue); ok {
+ if len(v.return_types) == 0 {
+ return false
+ }
+ for arg in v.return_types {
+ expr := arg.type
+ if expr == nil {
+ expr = arg.default_value
+ }
+
+ if !append_arg(ast_context, expr, results, used_named) {
+ return false
+ }
+ }
+ return true
+ } else {
+ append(results, call_arg)
+ return true
+ }
+ } else {
+ return false
+ }
+
+ return true
+ }
+
+ for arg in call.args {
+ if !append_arg(ast_context, arg, &results, &used_named) {
+ return {}, false
+ }
+ }
+
+ return results[:], true
+}
+
/*
Figure out which function the call expression is using out of the list from proc group
*/
@@ -721,12 +823,21 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
}
resolve_all_possibilities := should_resolve_all_proc_overload_possibilities(ast_context, call_expr)
- call_unnamed_arg_count := 0
- if call_expr != nil {
- call_unnamed_arg_count = get_unnamed_arg_count(call_expr.args)
- }
candidates := make([dynamic]Candidate, context.temp_allocator)
+ call_args, ok := expand_call_args(ast_context, call_expr)
+ if !ok {
+ return {}, false
+ }
+
+ if !resolve_all_possibilities {
+ for arg in call_args {
+ if arg.is_poly_type {
+ resolve_all_possibilities = true
+ break
+ }
+ }
+ }
for arg_expr in group.args {
f := Symbol{}
@@ -735,110 +846,52 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
symbol = f,
score = 1,
}
- if call_expr == nil || (resolve_all_possibilities && len(call_expr.args) == 0) {
+ if call_expr == nil || (resolve_all_possibilities && len(call_args) == 0) {
append(&candidates, candidate)
break next_fn
}
if procedure, ok := f.value.(SymbolProcedureValue); ok {
i := 0
- named := false
if !resolve_all_possibilities {
arg_count := get_proc_arg_count(procedure)
- if call_expr != nil && arg_count < len(call_expr.args) {
+ if call_expr != nil && arg_count < len(call_args) {
break next_fn
}
- if arg_count == len(call_expr.args) {
+ if arg_count == len(call_args) {
candidate.score /= 2
}
}
for proc_arg in procedure.arg_types {
for name in proc_arg.names {
- if i >= len(call_expr.args) {
+ if i >= len(call_args) {
+ i += 1
continue
}
- call_arg := call_expr.args[i]
+ call_arg := call_args[i]
+ i += 1
ast_context.use_locals = true
- call_symbol: Symbol
arg_symbol: Symbol
ok: bool
- is_call_arg_nil: bool
- implicit_selector: ^ast.Implicit_Selector_Expr
- if _, ok = call_arg.derived.(^ast.Bad_Expr); ok {
+ if call_arg.bad_expr {
continue
}
- //named parameter
- if field, is_field := call_arg.derived.(^ast.Field_Value); is_field {
- named = true
- if ident, is_ident := field.value.derived.(^ast.Ident); is_ident && ident.name == "nil" {
- is_call_arg_nil = true
- ok = true
- } else if implicit, is_implicit := field.value.derived.(^ast.Implicit_Selector_Expr);
- is_implicit {
- implicit_selector = implicit
- ok = true
- } else {
- call_symbol, ok = resolve_call_arg_type_expression(ast_context, field.value)
- if !ok {
- break next_fn
- }
- }
-
- if ident, is_ident := field.field.derived.(^ast.Ident); is_ident {
- i, ok = get_field_list_name_index(
- field.field.derived.(^ast.Ident).name,
- procedure.arg_types,
- )
- } else {
- break next_fn
- }
- } else {
- if named {
- log.error("Expected name parameter after starting named parmeter phase")
- return {}, false
- }
- if ident, is_ident := call_arg.derived.(^ast.Ident); is_ident && ident.name == "nil" {
- is_call_arg_nil = true
- ok = true
- } else if implicit, is_implicit_selector := call_arg.derived.(^ast.Implicit_Selector_Expr);
- is_implicit_selector {
- implicit_selector = implicit
- ok = true
- } else {
- call_symbol, ok = resolve_call_arg_type_expression(ast_context, call_arg)
- }
- }
-
- if !ok {
- break next_fn
- }
-
-
- if p, ok := call_symbol.value.(SymbolProcedureValue); ok {
- if len(p.return_types) != 1 {
- break next_fn
- }
- if s, ok := resolve_call_arg_type_expression(ast_context, p.return_types[0].type); ok {
- call_symbol = s
- }
- }
-
- // If an arg is a parapoly type, we assume it can match any symbol and return all possible
- // matches
- if _, ok := call_symbol.value.(SymbolPolyTypeValue); ok {
- resolve_all_possibilities = true
+ if call_arg.is_poly_type {
continue
}
proc_arg := proc_arg
- if named {
- proc_arg = procedure.arg_types[i]
+ if call_arg.named {
+ proc_arg, ok = get_proc_arg_type_from_name(procedure, call_arg.name)
+ if !ok {
+ break next_fn
+ }
}
if proc_arg.type != nil {
@@ -851,11 +904,17 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
break next_fn
}
- if implicit_selector != nil {
+ // TODO: check intrinsics for parapoly types?
+ if _, is_poly := arg_symbol.value.(SymbolPolyTypeValue); is_poly {
+ candidate.score += 1
+ continue
+ }
+
+ if call_arg.implicit_selector != nil {
if value, ok := arg_symbol.value.(SymbolEnumValue); ok {
found: bool
for name in value.names {
- if implicit_selector.field.name == name {
+ if call_arg.implicit_selector.field.name == name {
found = true
break
}
@@ -863,27 +922,25 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
if found {
continue
}
-
}
break next_fn
}
- if is_call_arg_nil {
+ if call_arg.is_nil {
if is_valid_nil_symbol(arg_symbol) {
continue
} else {
break next_fn
}
-
}
- if !is_symbol_same_typed(ast_context, call_symbol, arg_symbol, proc_arg.flags) {
+ if !is_symbol_same_typed(ast_context, call_arg.symbol, arg_symbol, proc_arg.flags) {
found := false
// Are we a union variant
if value, ok := arg_symbol.value.(SymbolUnionValue); ok {
for variant in value.types {
if symbol, ok := resolve_type_expression(ast_context, variant); ok {
- if is_symbol_same_typed(ast_context, call_symbol, symbol, proc_arg.flags) {
+ if is_symbol_same_typed(ast_context, call_arg.symbol, symbol, proc_arg.flags) {
// matching union types are a low priority
candidate.score = 1000000
found = true
@@ -894,11 +951,11 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
}
// Do we contain a using that matches
- if value, ok := call_symbol.value.(SymbolStructValue); ok {
+ if value, ok := call_arg.symbol.value.(SymbolStructValue); ok {
using_score := 1000000
for k in value.usings {
if symbol, ok := resolve_type_expression(ast_context, value.types[k]); ok {
- symbol.pointers = call_symbol.pointers
+ symbol.pointers = call_arg.symbol.pointers
if is_symbol_same_typed(ast_context, symbol, arg_symbol, proc_arg.flags) {
if k < using_score {
using_score = k
@@ -914,8 +971,6 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
break next_fn
}
}
-
- i += 1
}
}
@@ -948,14 +1003,14 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro
return {}, false
}
- symbol, ok := get_candidate_symbol(candidates[:], resolve_all_possibilities)
+ symbol, ok_canidate := get_candidate_symbol(candidates[:], resolve_all_possibilities)
if call_expr != nil {
ast_context.call_expr_recursion_cache[cast(rawptr)call_expr] = SymbolResult {
symbol = symbol,
- ok = ok,
+ ok = ok_canidate,
}
}
- return symbol, ok
+ return symbol, ok_canidate
}
resolve_call_arg_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Expr) -> (Symbol, bool) {
@@ -1441,20 +1496,20 @@ resolve_index_expr :: proc(ast_context: ^AstContext, index_expr: ^ast.Index_Expr
#partial switch v in indexed.value {
case SymbolDynamicArrayValue:
if .Soa in indexed.flags {
- indexed.flags |= { .SoaPointer }
+ indexed.flags |= {.SoaPointer}
return indexed, true
}
ok = internal_resolve_type_expression(ast_context, v.expr, &symbol)
case SymbolSliceValue:
ok = internal_resolve_type_expression(ast_context, v.expr, &symbol)
if .Soa in indexed.flags {
- indexed.flags |= { .SoaPointer }
+ indexed.flags |= {.SoaPointer}
return indexed, true
}
case SymbolFixedArrayValue:
ok = internal_resolve_type_expression(ast_context, v.expr, &symbol)
if .Soa in indexed.flags {
- indexed.flags |= { .SoaPointer }
+ indexed.flags |= {.SoaPointer}
return indexed, true
}
case SymbolMapValue:
@@ -1688,10 +1743,31 @@ resolve_selector_expression :: proc(ast_context: ^AstContext, node: ^ast.Selecto
try_build_package(ast_context.current_package)
if node.field != nil {
- return resolve_symbol_return(ast_context, lookup(node.field.name, selector.pkg, node.pos.file))
+ field_symbol, ok := lookup(node.field.name, selector.pkg, node.pos.file)
+ if ok {
+ if pkg_alias_symbol, ok := resolve_field_through_package_alias(ast_context, field_symbol, selector.pkg); ok {
+ return pkg_alias_symbol, true
+ }
+ return resolve_symbol_return(ast_context, field_symbol)
+ }
+ return {}, false
} else {
return Symbol{}, false
}
+ case SymbolBasicValue:
+ if s.ident != nil && node.field != nil {
+ if symbol, ok := resolve_field_access_through_imported_alias(ast_context, s.ident, node); ok {
+ return symbol, true
+ }
+ }
+ case SymbolGenericValue:
+ if s.expr != nil {
+ if ident, ok := s.expr.derived.(^ast.Ident); ok && node.field != nil {
+ if symbol, ok := resolve_field_access_through_imported_alias(ast_context, ident, node); ok {
+ return symbol, true
+ }
+ }
+ }
case SymbolEnumValue:
// enum members probably require own symbol value
selector.type = .EnumMember
@@ -1713,6 +1789,91 @@ resolve_selector_expression :: proc(ast_context: ^AstContext, node: ^ast.Selecto
return {}, false
}
+is_path_package_name :: #force_inline proc(name: string) -> bool {
+ return strings.contains(name, "/")
+}
+
+get_package_ident_from_symbol :: proc(symbol: Symbol) -> (ident: ^ast.Ident, ok: bool) {
+ #partial switch v in symbol.value {
+ case SymbolBasicValue:
+ if v.ident != nil {
+ return v.ident, true
+ }
+ case SymbolGenericValue:
+ if v.expr != nil {
+ if ident, ok := v.expr.derived.(^ast.Ident); ok {
+ return ident, true
+ }
+ }
+ }
+ return nil, false
+}
+
+resolve_ident_as_package :: proc(ast_context: ^AstContext, ident: ^ast.Ident, context_pkg: string) -> (Symbol, bool) {
+ ident_pkg, pkg_ok := internal_resolve_type_identifier(ast_context, ident^)
+ if pkg_ok && ident_pkg.type == .Package {
+ return ident_pkg, true
+ }
+ return {}, false
+}
+
+resolve_field_through_package_alias :: proc(ast_context: ^AstContext, field_symbol: Symbol, context_pkg: string) -> (Symbol, bool) {
+ ident, ok := get_package_ident_from_symbol(field_symbol)
+ if !ok {
+ return {}, false
+ }
+
+ if is_path_package_name(ident.name) {
+ return Symbol{
+ type = .Package,
+ pkg = ident.name,
+ value = SymbolPackageValue{},
+ }, true
+ }
+
+ current_package := ast_context.current_package
+ defer {
+ ast_context.current_package = current_package
+ }
+ ast_context.current_package = context_pkg
+
+ pkg_symbol, pkg_ok := resolve_ident_as_package(ast_context, ident, context_pkg)
+ if pkg_ok && pkg_symbol.type == .Package {
+ return pkg_symbol, true
+ }
+
+ return {}, false
+}
+
+resolve_field_access_through_imported_alias :: proc(ast_context: ^AstContext, ident: ^ast.Ident, node: ^ast.Selector_Expr) -> (Symbol, bool) {
+ for imp in ast_context.imports {
+ if strings.compare(imp.base, ident.name) == 0 {
+ try_build_package(ast_context.current_package)
+ if node.field != nil {
+ symbol, ok := lookup(node.field.name, imp.name, node.pos.file)
+ if ok {
+ return resolve_symbol_return(ast_context, symbol)
+ }
+ }
+ }
+ }
+
+ pkg_symbol, pkg_ok := internal_resolve_type_identifier(ast_context, ident^)
+ if pkg_ok {
+ if _, ok2 := pkg_symbol.value.(SymbolPackageValue); ok2 {
+ try_build_package(ast_context.current_package)
+ if node.field != nil {
+ symbol, ok := lookup(node.field.name, pkg_symbol.pkg, node.pos.file)
+ if ok {
+ return resolve_symbol_return(ast_context, symbol)
+ }
+ }
+ }
+ }
+
+ return {}, false
+}
+
// returns the symbol of the first return type of a proc
resolve_symbol_proc_first_return_symbol :: proc(ast_context: ^AstContext, symbol: Symbol) -> (Symbol, bool) {
if v, ok := symbol.value.(SymbolProcedureValue); ok {
diff --git a/src/server/collector.odin b/src/server/collector.odin
index 0f85774..a7785fa 100644
--- a/src/server/collector.odin
+++ b/src/server/collector.odin
@@ -1036,9 +1036,9 @@ get_package_mapping :: proc(file: ast.File, config: ^common.Config, directory: s
package_map[name] = full
} else {
name: string
-
+ pkg_name := imp.fullpath[1:len(imp.fullpath) - 1]
full := path.join(
- elems = {directory, imp.fullpath[1:len(imp.fullpath) - 1]},
+ elems = {directory, pkg_name},
allocator = context.temp_allocator,
)
full = path.clean(full, context.temp_allocator)
@@ -1048,6 +1048,12 @@ get_package_mapping :: proc(file: ast.File, config: ^common.Config, directory: s
} else {
name = path.base(full, false, context.temp_allocator)
}
+ // Check if the package already exists in the index and use that path
+ // This handles the case where packages are indexed separately (e.g., in tests)
+ test_path := path.join(elems = {"test", pkg_name}, allocator = context.temp_allocator)
+ if _, exists := indexer.index.collection.packages[test_path]; exists {
+ full = test_path
+ }
package_map[name] = full
}
@@ -1098,6 +1104,12 @@ replace_package_alias_node :: proc(node: ^ast.Node, package_map: map[string]stri
#partial switch n in node.derived {
case ^Bad_Expr:
case ^Ident:
+ // Replace stand-alone identifiers that are package aliases
+ if package_name, ok := package_map[n.name]; ok {
+ n.name = get_index_unique_string(collection, package_name)
+ } else if strings.contains(n.name, "/") {
+ n.name = get_index_unique_string(collection, n.name)
+ }
case ^Implicit:
case ^Undef:
case ^Basic_Lit:
diff --git a/src/server/generics.odin b/src/server/generics.odin
index 347f484..71a0da7 100644
--- a/src/server/generics.odin
+++ b/src/server/generics.odin
@@ -39,6 +39,8 @@ resolve_poly :: proc(
if ident, ok := unwrap_ident(type); ok {
if untyped_value, ok := call_symbol.value.(SymbolUntypedValue); ok {
save_poly_map(ident, symbol_to_expr(call_symbol, call_node.pos.file), poly_map)
+ } else if .Anonymous in call_symbol.flags {
+ save_poly_map(ident, call_node, poly_map)
} else {
save_poly_map(
ident,
@@ -515,6 +517,7 @@ resolve_generic_function_symbol :: proc(
if file == "" {
file = call_expr.args[i].pos.file
}
+
symbol_expr := symbol_to_expr(symbol, file, context.temp_allocator)
if symbol_expr == nil {
@@ -525,17 +528,13 @@ resolve_generic_function_symbol :: proc(
if symbol_value, ok := symbol.value.(SymbolProcedureValue); ok && len(symbol_value.return_types) > 0 {
call_arg_count += get_proc_return_value_count(symbol_value.return_types)
if _, ok := call_expr.args[i].derived.(^ast.Call_Expr); ok {
- if symbol_value.return_types[0].type != nil {
- if symbol, ok = resolve_type_expression(ast_context, symbol_value.return_types[0].type);
- ok {
- symbol_expr = symbol_to_expr(
- symbol,
- call_expr.args[i].pos.file,
- context.temp_allocator,
- )
- if symbol_expr == nil {
- return {}, false
- }
+ ret_type := symbol_value.return_types[0].type
+ if ret_type == nil {
+ ret_type = symbol_value.return_types[0].default_value
+ }
+ if ret_type != nil {
+ if symbol, ok = resolve_type_expression(ast_context, ret_type); ok {
+ symbol_expr = ret_type
}
}
}
diff --git a/src/server/references.odin b/src/server/references.odin
index aeebd43..ddd5f17 100644
--- a/src/server/references.odin
+++ b/src/server/references.odin
@@ -14,29 +14,6 @@ import "core:strings"
import "src:common"
-fullpaths: [dynamic]string
-
-walk_directories :: proc(info: os.File_Info, in_err: os.Error, user_data: rawptr) -> (err: os.Error, skip_dir: bool) {
- document := cast(^Document)user_data
-
- if info.type == .Directory {
- return nil, false
- }
-
- if info.fullpath == "" {
- return nil, false
- }
-
- if strings.contains(info.name, ".odin") {
- slash_path, _ := filepath.replace_path_separators(info.fullpath, '/', context.temp_allocator)
- if slash_path != document.fullpath {
- append(&fullpaths, strings.clone(info.fullpath, context.temp_allocator))
- }
- }
-
- return nil, false
-}
-
prepare_references :: proc(
document: ^Document,
ast_context: ^AstContext,
@@ -236,12 +213,13 @@ resolve_references :: proc(
ast_context: ^AstContext,
position_context: ^DocumentPositionContext,
current_file_only := false,
+ include_declaration := true,
) -> (
[]common.Location,
bool,
) {
locations := make([dynamic]common.Location, 0, ast_context.allocator)
- fullpaths = make([dynamic]string, 0, ast_context.allocator)
+ fullpaths := make([dynamic]string, 0, ast_context.allocator)
symbol, resolve_flag, ok := prepare_references(document, ast_context, position_context)
@@ -253,9 +231,13 @@ resolve_references :: proc(
for k, v in symbols_and_nodes {
if strings.equal_fold(v.symbol.uri, symbol.uri) && v.symbol.range == symbol.range {
node_uri := common.create_uri(v.node.pos.file, ast_context.allocator)
-
range := common.get_token_range(v.node^, ast_context.file.src)
+ if !include_declaration && v.symbol.range == range && strings.equal_fold(node_uri.uri, symbol.uri) {
+ // This is the declaration and so we skip it
+ continue
+ }
+
//We don't have to have the `.` with, otherwise it renames the dot.
if _, ok := v.node.derived.(^ast.Implicit_Selector_Expr); ok {
range.start.character += 1
@@ -309,9 +291,9 @@ resolve_references :: proc(
context.allocator = runtime.arena_allocator(&arena)
- fullpaths := slice.unique(fullpaths[:])
+ paths := slice.unique(fullpaths[:])
- for fullpath in fullpaths {
+ for fullpath in paths {
dir := filepath.dir(fullpath)
base := filepath.base(dir)
forward_dir, _ := filepath.replace_path_separators(dir, '/', context.allocator)
@@ -383,6 +365,13 @@ resolve_references :: proc(
if strings.equal_fold(v.symbol.uri, symbol.uri) && v.symbol.range == symbol.range {
node_uri := common.create_uri(v.node.pos.file, ast_context.allocator)
range := common.get_token_range(v.node^, string(document.text))
+
+ if !include_declaration &&
+ v.symbol.range == range &&
+ strings.equal_fold(node_uri.uri, symbol.uri) {
+ // This is the declaration and so we skip it
+ continue
+ }
//We don't have to have the `.` with, otherwise it renames the dot.
if _, ok := v.node.derived.(^ast.Implicit_Selector_Expr); ok {
range.start.character += 1
@@ -406,6 +395,7 @@ get_references :: proc(
document: ^Document,
position: common.Position,
current_file_only := false,
+ include_declaration := true,
) -> (
[]common.Location,
bool,
@@ -435,7 +425,13 @@ get_references :: proc(
get_locals(document.ast, position_context.function, &ast_context, &position_context)
}
- locations, ok2 := resolve_references(document, &ast_context, &position_context, current_file_only)
+ locations, ok2 := resolve_references(
+ document,
+ &ast_context,
+ &position_context,
+ current_file_only,
+ include_declaration = include_declaration,
+ )
temp_locations := make([dynamic]common.Location, 0, context.temp_allocator)
diff --git a/src/server/requests.odin b/src/server/requests.odin
index 3aec455..57d190d 100644
--- a/src/server/requests.odin
+++ b/src/server/requests.odin
@@ -214,12 +214,7 @@ read_and_parse_body :: proc(reader: ^Reader, header: Header) -> (json.Value, boo
return value, true
}
-call_map: map[string]proc(
- _: json.Value,
- _: RequestId,
- _: ^common.Config,
- _: ^Writer,
-) -> common.Error = {
+call_map: map[string]proc(_: json.Value, _: RequestId, _: ^common.Config, _: ^Writer) -> common.Error = {
"initialize" = request_initialize,
"initialized" = request_initialized,
"shutdown" = request_shutdown,
@@ -331,10 +326,7 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common.
if !ok {
log.errorf("Failed to find method: %#v", root)
- response := make_response_message_error(
- id = id,
- error = ResponseError{code = .MethodNotFound, message = ""},
- )
+ response := make_response_message_error(id = id, error = ResponseError{code = .MethodNotFound, message = ""})
send_error(response, writer)
return
}
@@ -352,10 +344,7 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common.
} else {
err := fn(root["params"], id, config, writer)
if err != .None {
- response := make_response_message_error(
- id = id,
- error = ResponseError{code = err, message = ""},
- )
+ response := make_response_message_error(id = id, error = ResponseError{code = err, message = ""})
send_error(response, writer)
}
}
@@ -364,20 +353,13 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common.
//log.errorf("time duration %v for %v", time.duration_milliseconds(diff), method)
}
-read_ols_initialize_options :: proc(
- config: ^common.Config,
- ols_config: OlsConfig,
- uri: common.Uri,
-) {
- config.disable_parser_errors =
- ols_config.disable_parser_errors.(bool) or_else config.disable_parser_errors
+read_ols_initialize_options :: proc(config: ^common.Config, ols_config: OlsConfig, uri: common.Uri) {
+ config.disable_parser_errors = ols_config.disable_parser_errors.(bool) or_else config.disable_parser_errors
config.thread_count = ols_config.thread_pool_count.(int) or_else config.thread_count
- config.enable_document_symbols =
- ols_config.enable_document_symbols.(bool) or_else config.enable_document_symbols
+ config.enable_document_symbols = ols_config.enable_document_symbols.(bool) or_else config.enable_document_symbols
config.enable_format = ols_config.enable_format.(bool) or_else config.enable_format
config.enable_hover = ols_config.enable_hover.(bool) or_else config.enable_hover
- config.enable_semantic_tokens =
- ols_config.enable_semantic_tokens.(bool) or_else config.enable_semantic_tokens
+ config.enable_semantic_tokens = ols_config.enable_semantic_tokens.(bool) or_else config.enable_semantic_tokens
config.enable_unused_imports_reporting =
ols_config.enable_unused_imports_reporting.(bool) or_else config.enable_unused_imports_reporting
config.enable_procedure_context =
@@ -388,20 +370,20 @@ read_ols_initialize_options :: proc(
ols_config.enable_document_highlights.(bool) or_else config.enable_document_highlights
config.enable_completion_matching =
ols_config.enable_completion_matching.(bool) or_else config.enable_completion_matching
- config.enable_document_links =
- ols_config.enable_document_links.(bool) or_else config.enable_document_links
+ config.enable_document_links = ols_config.enable_document_links.(bool) or_else config.enable_document_links
config.enable_comp_lit_signature_help =
ols_config.enable_comp_lit_signature_help.(bool) or_else config.enable_comp_lit_signature_help
config.enable_comp_lit_signature_help_use_docs =
ols_config.enable_comp_lit_signature_help_use_docs.(bool) or_else config.enable_comp_lit_signature_help_use_docs
+ config.enable_code_action_invert_if =
+ ols_config.enable_code_action_invert_if.(bool) or_else config.enable_code_action_invert_if
config.verbose = ols_config.verbose.(bool) or_else config.verbose
config.file_log = ols_config.file_log.(bool) or_else config.file_log
config.enable_procedure_snippet =
ols_config.enable_procedure_snippet.(bool) or_else config.enable_procedure_snippet
- config.enable_auto_import =
- ols_config.enable_auto_import.(bool) or_else config.enable_auto_import
+ config.enable_auto_import = ols_config.enable_auto_import.(bool) or_else config.enable_auto_import
config.enable_checker_only_saved =
ols_config.enable_checker_only_saved.(bool) or_else config.enable_checker_only_saved
@@ -417,10 +399,7 @@ read_ols_initialize_options :: proc(
}
if ols_config.odin_root_override != "" {
- config.odin_root_override = strings.clone(
- ols_config.odin_root_override,
- context.temp_allocator,
- )
+ config.odin_root_override = strings.clone(ols_config.odin_root_override, context.temp_allocator)
allocated: bool
config.odin_root_override, allocated = common.resolve_home_dir(config.odin_root_override)
@@ -473,8 +452,7 @@ read_ols_initialize_options :: proc(
config.enable_inlay_hints_implicit_return =
ols_config.enable_inlay_hints_implicit_return.(bool) or_else config.enable_inlay_hints_implicit_return
- config.enable_fake_method =
- ols_config.enable_fake_methods.(bool) or_else config.enable_fake_method
+ config.enable_fake_method = ols_config.enable_fake_methods.(bool) or_else config.enable_fake_method
config.enable_overload_resolution =
ols_config.enable_overload_resolution.(bool) or_else config.enable_overload_resolution
@@ -512,10 +490,7 @@ read_ols_initialize_options :: proc(
} else {
final_path, _ = filepath.replace_path_separators(
common.get_case_sensitive_path(
- path.join(
- elems = {uri.path, forward_path},
- allocator = context.temp_allocator,
- ),
+ path.join(elems = {uri.path, forward_path}, allocator = context.temp_allocator),
context.temp_allocator,
),
'/',
@@ -537,11 +512,7 @@ read_ols_initialize_options :: proc(
log.errorf("Failed to find absolute address of collection: %v", final_path, err)
config.collections[strings.clone(it.name)] = strings.clone(final_path)
} else {
- slashed_path, _ := filepath.replace_path_separators(
- abs_final_path,
- '/',
- context.temp_allocator,
- )
+ slashed_path, _ := filepath.replace_path_separators(abs_final_path, '/', context.temp_allocator)
config.collections[strings.clone(it.name)] = strings.clone(slashed_path)
}
@@ -587,8 +558,7 @@ read_ols_initialize_options :: proc(
}
if odin_core_env != "" {
- if abs_core_env, err := filepath.abs(odin_core_env, context.temp_allocator);
- err == nil {
+ if abs_core_env, err := filepath.abs(odin_core_env, context.temp_allocator); err == nil {
odin_core_env = abs_core_env
}
}
@@ -599,11 +569,7 @@ read_ols_initialize_options :: proc(
// Insert the default collections if they are not specified in the config.
if odin_core_env != "" {
- forward_path, _ := filepath.replace_path_separators(
- odin_core_env,
- '/',
- context.temp_allocator,
- )
+ forward_path, _ := filepath.replace_path_separators(odin_core_env, '/', context.temp_allocator)
// base
if "base" not_in config.collections {
@@ -631,10 +597,7 @@ read_ols_initialize_options :: proc(
// shared
if "shared" not_in config.collections {
- shared_path := path.join(
- elems = {forward_path, "shared"},
- allocator = context.allocator,
- )
+ shared_path := path.join(elems = {forward_path, "shared"}, allocator = context.allocator)
if os.exists(shared_path) {
config.collections[strings.clone("shared")] = shared_path
} else {
@@ -739,10 +702,7 @@ request_initialize :: proc(
read_ols_initialize_options(config, initialize_params.initializationOptions, uri)
// Apply ols.json config.
- ols_config_path := path.join(
- elems = {uri.path, "ols.json"},
- allocator = context.temp_allocator,
- )
+ ols_config_path := path.join(elems = {uri.path, "ols.json"}, allocator = context.temp_allocator)
read_ols_config(ols_config_path, config, uri)
} else {
read_ols_initialize_options(config, initialize_params.initializationOptions, {})
@@ -763,8 +723,7 @@ request_initialize :: proc(
config.enable_label_details =
initialize_params.capabilities.textDocument.completion.completionItem.labelDetailsSupport
- config.enable_snippets &=
- initialize_params.capabilities.textDocument.completion.completionItem.snippetSupport
+ config.enable_snippets &= initialize_params.capabilities.textDocument.completion.completionItem.snippetSupport
config.signature_offset_support =
initialize_params.capabilities.textDocument.signatureHelp.signatureInformation.parameterInformation.labelOffsetSupport
@@ -773,17 +732,12 @@ request_initialize :: proc(
signatureTriggerCharacters := []string{"(", ","}
signatureRetriggerCharacters := []string{","}
- semantic_range_support :=
- initialize_params.capabilities.textDocument.semanticTokens.requests.range
+ semantic_range_support := initialize_params.capabilities.textDocument.semanticTokens.requests.range
response := make_response_message(
params = ResponseInitializeParams {
capabilities = ServerCapabilities {
- textDocumentSync = TextDocumentSyncOptions {
- openClose = true,
- change = 2,
- save = {includeText = true},
- },
+ textDocumentSync = TextDocumentSyncOptions{openClose = true, change = 2, save = {includeText = true}},
renameProvider = RenameOptions{prepareProvider = true},
workspaceSymbolProvider = true,
referencesProvider = config.enable_references,
@@ -814,10 +768,7 @@ request_initialize :: proc(
hoverProvider = config.enable_hover,
documentFormattingProvider = config.enable_format,
documentLinkProvider = {resolveProvider = false},
- codeActionProvider = {
- resolveProvider = false,
- codeActionKinds = {"refactor.rewrite"},
- },
+ codeActionProvider = {resolveProvider = false, codeActionKinds = {"refactor.rewrite"}},
},
},
id = id,
@@ -883,12 +834,7 @@ request_initialized :: proc(
return .None
}
-request_shutdown :: proc(
- params: json.Value,
- id: RequestId,
- config: ^common.Config,
- writer: ^Writer,
-) -> common.Error {
+request_shutdown :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error {
response := make_response_message(params = nil, id = id)
send_response(response, writer)
@@ -1003,12 +949,7 @@ request_completion :: proc(
}
list: CompletionList
- list, ok = get_completion_list(
- document,
- completition_params.position,
- completition_params.context_,
- config,
- )
+ list, ok = get_completion_list(document, completition_params.position, completition_params.context_, config)
if !ok {
return .InternalError
@@ -1102,12 +1043,7 @@ request_format_document :: proc(
return .None
}
-notification_exit :: proc(
- params: json.Value,
- id: RequestId,
- config: ^common.Config,
- writer: ^Writer,
-) -> common.Error {
+notification_exit :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error {
config.running = false
return .None
}
@@ -1134,12 +1070,7 @@ notification_did_open :: proc(
defer delete(open_params.textDocument.uri)
- if n := document_open(
- open_params.textDocument.uri,
- open_params.textDocument.text,
- config,
- writer,
- ); n != .None {
+ if n := document_open(open_params.textDocument.uri, open_params.textDocument.text, config, writer); n != .None {
return .InternalError
}
@@ -1375,12 +1306,7 @@ request_document_symbols :: proc(
return .None
}
-request_hover :: proc(
- params: json.Value,
- id: RequestId,
- config: ^common.Config,
- writer: ^Writer,
-) -> common.Error {
+request_hover :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error {
params_object, ok := params.(json.Object)
if !ok {
@@ -1528,12 +1454,7 @@ request_prepare_rename :: proc(
return .None
}
-request_rename :: proc(
- params: json.Value,
- id: RequestId,
- config: ^common.Config,
- writer: ^Writer,
-) -> common.Error {
+request_rename :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error {
params_object, ok := params.(json.Object)
if !ok {
@@ -1580,7 +1501,13 @@ request_references :: proc(
reference_param: ReferenceParams
- if unmarshal(params, reference_param, context.temp_allocator) != nil {
+ // Due to the field named `context`, we need to use json tags and this is the easiest way to handle that right now.
+ data, err := json.marshal(params_object)
+ if err != nil {
+ return .ParseError
+ }
+
+ if err := json.unmarshal(data, &reference_param, allocator = context.temp_allocator); err != nil {
return .ParseError
}
@@ -1591,7 +1518,11 @@ request_references :: proc(
}
locations: []common.Location
- locations, ok = get_references(document, reference_param.position)
+ locations, ok = get_references(
+ document,
+ reference_param.position,
+ include_declaration = reference_param.ctx.includeDeclaration,
+ )
if !ok {
return .InternalError
@@ -1783,11 +1714,6 @@ request_workspace_symbols :: proc(
return .None
}
-request_noop :: proc(
- params: json.Value,
- id: RequestId,
- config: ^common.Config,
- writer: ^Writer,
-) -> common.Error {
+request_noop :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error {
return .None
}
diff --git a/src/server/types.odin b/src/server/types.odin
index c65e181..fc52119 100644
--- a/src/server/types.odin
+++ b/src/server/types.odin
@@ -436,6 +436,7 @@ OlsConfig :: struct {
enable_procedure_snippet: Maybe(bool),
enable_checker_only_saved: Maybe(bool),
enable_auto_import: Maybe(bool),
+ enable_code_action_invert_if: Maybe(bool),
disable_parser_errors: Maybe(bool),
verbose: Maybe(bool),
file_log: Maybe(bool),
@@ -561,9 +562,14 @@ PrepareRenameParams :: struct {
position: common.Position,
}
+ReferenceContext :: struct {
+ includeDeclaration: bool,
+}
+
ReferenceParams :: struct {
textDocument: TextDocumentIdentifier,
position: common.Position,
+ ctx: ReferenceContext `json:"context"`,
}
HighlightParams :: struct {
diff --git a/src/testing/testing.odin b/src/testing/testing.odin
index 5e927a7..46c597f 100644
--- a/src/testing/testing.odin
+++ b/src/testing/testing.odin
@@ -464,11 +464,12 @@ expect_reference_locations :: proc(
src: ^Source,
expect_locations: []common.Location,
expect_excluded: []common.Location = nil,
+ include_declaration := true,
) {
setup(src)
defer teardown(src)
- locations, ok := server.get_references(src.document, src.position)
+ locations, ok := server.get_references(src.document, src.position, include_declaration = include_declaration)
for expect_location in expect_locations {
match := false