diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/common/config.odin | 1 | ||||
| -rw-r--r-- | src/server/action.odin | 5 | ||||
| -rw-r--r-- | src/server/action_invert_if_statements.odin | 2 | ||||
| -rw-r--r-- | src/server/analysis.odin | 349 | ||||
| -rw-r--r-- | src/server/collector.odin | 16 | ||||
| -rw-r--r-- | src/server/generics.odin | 21 | ||||
| -rw-r--r-- | src/server/references.odin | 52 | ||||
| -rw-r--r-- | src/server/requests.odin | 158 | ||||
| -rw-r--r-- | src/server/types.odin | 6 | ||||
| -rw-r--r-- | src/testing/testing.odin | 3 |
10 files changed, 358 insertions, 255 deletions
diff --git a/src/common/config.odin b/src/common/config.odin index f2ae68f..d1bfd31 100644 --- a/src/common/config.odin +++ b/src/common/config.odin @@ -41,6 +41,7 @@ Config :: struct { enable_document_links: bool, enable_comp_lit_signature_help: bool, enable_comp_lit_signature_help_use_docs: bool, + enable_code_action_invert_if: bool, disable_parser_errors: bool, thread_count: int, file_log: bool, diff --git a/src/server/action.odin b/src/server/action.odin index 11e8f86..ce0dbb3 100644 --- a/src/server/action.odin +++ b/src/server/action.odin @@ -79,7 +79,10 @@ get_code_actions :: proc(document: ^Document, range: common.Range, config: ^comm &actions, ) } - add_invert_if_action(document, position_context.position, strings.clone(document.uri.uri), &actions) + + if config.enable_code_action_invert_if { + add_invert_if_action(document, position_context.position, strings.clone(document.uri.uri), &actions) + } return actions[:], true } diff --git a/src/server/action_invert_if_statements.odin b/src/server/action_invert_if_statements.odin index 047b3a2..336caaa 100644 --- a/src/server/action_invert_if_statements.odin +++ b/src/server/action_invert_if_statements.odin @@ -3,10 +3,8 @@ package server import "core:fmt" -import "core:log" import "core:odin/ast" import "core:odin/tokenizer" -import path "core:path/slashpath" import "core:strings" import "src:common" diff --git a/src/server/analysis.odin b/src/server/analysis.odin index 1a85708..e980913 100644 --- a/src/server/analysis.odin +++ b/src/server/analysis.odin @@ -697,6 +697,108 @@ should_resolve_all_proc_overload_possibilities :: proc(ast_context: ^AstContext, return ast_context.position_hint == .Completion || ast_context.position_hint == .SignatureHelp || call_expr == nil } +CallArg :: struct { + symbol: Symbol, + implicit_selector: ^ast.Implicit_Selector_Expr, + name: string, + named: bool, + is_nil: bool, + bad_expr: bool, + has_symbol: bool, + is_poly_type: bool, +} + +expand_call_args :: proc(ast_context: ^AstContext, call: ^ast.Call_Expr) -> ([]CallArg, bool) { + results := make([dynamic]CallArg, context.temp_allocator) + if call == nil { + return results[:], true + } + + used_named := false + append_arg :: proc( + ast_context: ^AstContext, + arg: ^ast.Expr, + results: ^[dynamic]CallArg, + used_named: ^bool, + ) -> bool { + ast_context.use_locals = true + + call_arg := CallArg{} + + if _, ok := arg.derived.(^ast.Bad_Expr); ok { + call_arg.bad_expr = true + append(results, call_arg) + return true + } + + value_expr := arg + + //named parameter + if field, ok := arg.derived.(^ast.Field_Value); ok { + call_arg.named = true + value_expr = field.value + used_named^ = true + + if ident, ok := field.field.derived.(^ast.Ident); ok { + call_arg.name = ident.name + } + } else if used_named^ { + log.error("Expected name parameter after starting named parmeter phase") + return false + } + + if ident, ok := value_expr.derived.(^ast.Ident); ok && ident.name == "nil" { + call_arg.is_nil = true + append(results, call_arg) + return true + } else if implicit, ok := value_expr.derived.(^ast.Implicit_Selector_Expr); ok { + call_arg.implicit_selector = implicit + append(results, call_arg) + return true + } + + if symbol, ok := resolve_call_arg_type_expression(ast_context, value_expr); ok { + call_arg.symbol = symbol + call_arg.has_symbol = true + if _, ok := symbol.value.(SymbolPolyTypeValue); ok { + call_arg.is_poly_type = true + append(results, call_arg) + return true + } else if v, ok := symbol.value.(SymbolProcedureValue); ok { + if len(v.return_types) == 0 { + return false + } + for arg in v.return_types { + expr := arg.type + if expr == nil { + expr = arg.default_value + } + + if !append_arg(ast_context, expr, results, used_named) { + return false + } + } + return true + } else { + append(results, call_arg) + return true + } + } else { + return false + } + + return true + } + + for arg in call.args { + if !append_arg(ast_context, arg, &results, &used_named) { + return {}, false + } + } + + return results[:], true +} + /* Figure out which function the call expression is using out of the list from proc group */ @@ -721,12 +823,21 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro } resolve_all_possibilities := should_resolve_all_proc_overload_possibilities(ast_context, call_expr) - call_unnamed_arg_count := 0 - if call_expr != nil { - call_unnamed_arg_count = get_unnamed_arg_count(call_expr.args) - } candidates := make([dynamic]Candidate, context.temp_allocator) + call_args, ok := expand_call_args(ast_context, call_expr) + if !ok { + return {}, false + } + + if !resolve_all_possibilities { + for arg in call_args { + if arg.is_poly_type { + resolve_all_possibilities = true + break + } + } + } for arg_expr in group.args { f := Symbol{} @@ -735,110 +846,52 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro symbol = f, score = 1, } - if call_expr == nil || (resolve_all_possibilities && len(call_expr.args) == 0) { + if call_expr == nil || (resolve_all_possibilities && len(call_args) == 0) { append(&candidates, candidate) break next_fn } if procedure, ok := f.value.(SymbolProcedureValue); ok { i := 0 - named := false if !resolve_all_possibilities { arg_count := get_proc_arg_count(procedure) - if call_expr != nil && arg_count < len(call_expr.args) { + if call_expr != nil && arg_count < len(call_args) { break next_fn } - if arg_count == len(call_expr.args) { + if arg_count == len(call_args) { candidate.score /= 2 } } for proc_arg in procedure.arg_types { for name in proc_arg.names { - if i >= len(call_expr.args) { + if i >= len(call_args) { + i += 1 continue } - call_arg := call_expr.args[i] + call_arg := call_args[i] + i += 1 ast_context.use_locals = true - call_symbol: Symbol arg_symbol: Symbol ok: bool - is_call_arg_nil: bool - implicit_selector: ^ast.Implicit_Selector_Expr - if _, ok = call_arg.derived.(^ast.Bad_Expr); ok { + if call_arg.bad_expr { continue } - //named parameter - if field, is_field := call_arg.derived.(^ast.Field_Value); is_field { - named = true - if ident, is_ident := field.value.derived.(^ast.Ident); is_ident && ident.name == "nil" { - is_call_arg_nil = true - ok = true - } else if implicit, is_implicit := field.value.derived.(^ast.Implicit_Selector_Expr); - is_implicit { - implicit_selector = implicit - ok = true - } else { - call_symbol, ok = resolve_call_arg_type_expression(ast_context, field.value) - if !ok { - break next_fn - } - } - - if ident, is_ident := field.field.derived.(^ast.Ident); is_ident { - i, ok = get_field_list_name_index( - field.field.derived.(^ast.Ident).name, - procedure.arg_types, - ) - } else { - break next_fn - } - } else { - if named { - log.error("Expected name parameter after starting named parmeter phase") - return {}, false - } - if ident, is_ident := call_arg.derived.(^ast.Ident); is_ident && ident.name == "nil" { - is_call_arg_nil = true - ok = true - } else if implicit, is_implicit_selector := call_arg.derived.(^ast.Implicit_Selector_Expr); - is_implicit_selector { - implicit_selector = implicit - ok = true - } else { - call_symbol, ok = resolve_call_arg_type_expression(ast_context, call_arg) - } - } - - if !ok { - break next_fn - } - - - if p, ok := call_symbol.value.(SymbolProcedureValue); ok { - if len(p.return_types) != 1 { - break next_fn - } - if s, ok := resolve_call_arg_type_expression(ast_context, p.return_types[0].type); ok { - call_symbol = s - } - } - - // If an arg is a parapoly type, we assume it can match any symbol and return all possible - // matches - if _, ok := call_symbol.value.(SymbolPolyTypeValue); ok { - resolve_all_possibilities = true + if call_arg.is_poly_type { continue } proc_arg := proc_arg - if named { - proc_arg = procedure.arg_types[i] + if call_arg.named { + proc_arg, ok = get_proc_arg_type_from_name(procedure, call_arg.name) + if !ok { + break next_fn + } } if proc_arg.type != nil { @@ -851,11 +904,17 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro break next_fn } - if implicit_selector != nil { + // TODO: check intrinsics for parapoly types? + if _, is_poly := arg_symbol.value.(SymbolPolyTypeValue); is_poly { + candidate.score += 1 + continue + } + + if call_arg.implicit_selector != nil { if value, ok := arg_symbol.value.(SymbolEnumValue); ok { found: bool for name in value.names { - if implicit_selector.field.name == name { + if call_arg.implicit_selector.field.name == name { found = true break } @@ -863,27 +922,25 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro if found { continue } - } break next_fn } - if is_call_arg_nil { + if call_arg.is_nil { if is_valid_nil_symbol(arg_symbol) { continue } else { break next_fn } - } - if !is_symbol_same_typed(ast_context, call_symbol, arg_symbol, proc_arg.flags) { + if !is_symbol_same_typed(ast_context, call_arg.symbol, arg_symbol, proc_arg.flags) { found := false // Are we a union variant if value, ok := arg_symbol.value.(SymbolUnionValue); ok { for variant in value.types { if symbol, ok := resolve_type_expression(ast_context, variant); ok { - if is_symbol_same_typed(ast_context, call_symbol, symbol, proc_arg.flags) { + if is_symbol_same_typed(ast_context, call_arg.symbol, symbol, proc_arg.flags) { // matching union types are a low priority candidate.score = 1000000 found = true @@ -894,11 +951,11 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro } // Do we contain a using that matches - if value, ok := call_symbol.value.(SymbolStructValue); ok { + if value, ok := call_arg.symbol.value.(SymbolStructValue); ok { using_score := 1000000 for k in value.usings { if symbol, ok := resolve_type_expression(ast_context, value.types[k]); ok { - symbol.pointers = call_symbol.pointers + symbol.pointers = call_arg.symbol.pointers if is_symbol_same_typed(ast_context, symbol, arg_symbol, proc_arg.flags) { if k < using_score { using_score = k @@ -914,8 +971,6 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro break next_fn } } - - i += 1 } } @@ -948,14 +1003,14 @@ resolve_function_overload :: proc(ast_context: ^AstContext, group: ^ast.Proc_Gro return {}, false } - symbol, ok := get_candidate_symbol(candidates[:], resolve_all_possibilities) + symbol, ok_canidate := get_candidate_symbol(candidates[:], resolve_all_possibilities) if call_expr != nil { ast_context.call_expr_recursion_cache[cast(rawptr)call_expr] = SymbolResult { symbol = symbol, - ok = ok, + ok = ok_canidate, } } - return symbol, ok + return symbol, ok_canidate } resolve_call_arg_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Expr) -> (Symbol, bool) { @@ -1441,20 +1496,20 @@ resolve_index_expr :: proc(ast_context: ^AstContext, index_expr: ^ast.Index_Expr #partial switch v in indexed.value { case SymbolDynamicArrayValue: if .Soa in indexed.flags { - indexed.flags |= { .SoaPointer } + indexed.flags |= {.SoaPointer} return indexed, true } ok = internal_resolve_type_expression(ast_context, v.expr, &symbol) case SymbolSliceValue: ok = internal_resolve_type_expression(ast_context, v.expr, &symbol) if .Soa in indexed.flags { - indexed.flags |= { .SoaPointer } + indexed.flags |= {.SoaPointer} return indexed, true } case SymbolFixedArrayValue: ok = internal_resolve_type_expression(ast_context, v.expr, &symbol) if .Soa in indexed.flags { - indexed.flags |= { .SoaPointer } + indexed.flags |= {.SoaPointer} return indexed, true } case SymbolMapValue: @@ -1688,10 +1743,31 @@ resolve_selector_expression :: proc(ast_context: ^AstContext, node: ^ast.Selecto try_build_package(ast_context.current_package) if node.field != nil { - return resolve_symbol_return(ast_context, lookup(node.field.name, selector.pkg, node.pos.file)) + field_symbol, ok := lookup(node.field.name, selector.pkg, node.pos.file) + if ok { + if pkg_alias_symbol, ok := resolve_field_through_package_alias(ast_context, field_symbol, selector.pkg); ok { + return pkg_alias_symbol, true + } + return resolve_symbol_return(ast_context, field_symbol) + } + return {}, false } else { return Symbol{}, false } + case SymbolBasicValue: + if s.ident != nil && node.field != nil { + if symbol, ok := resolve_field_access_through_imported_alias(ast_context, s.ident, node); ok { + return symbol, true + } + } + case SymbolGenericValue: + if s.expr != nil { + if ident, ok := s.expr.derived.(^ast.Ident); ok && node.field != nil { + if symbol, ok := resolve_field_access_through_imported_alias(ast_context, ident, node); ok { + return symbol, true + } + } + } case SymbolEnumValue: // enum members probably require own symbol value selector.type = .EnumMember @@ -1713,6 +1789,91 @@ resolve_selector_expression :: proc(ast_context: ^AstContext, node: ^ast.Selecto return {}, false } +is_path_package_name :: #force_inline proc(name: string) -> bool { + return strings.contains(name, "/") +} + +get_package_ident_from_symbol :: proc(symbol: Symbol) -> (ident: ^ast.Ident, ok: bool) { + #partial switch v in symbol.value { + case SymbolBasicValue: + if v.ident != nil { + return v.ident, true + } + case SymbolGenericValue: + if v.expr != nil { + if ident, ok := v.expr.derived.(^ast.Ident); ok { + return ident, true + } + } + } + return nil, false +} + +resolve_ident_as_package :: proc(ast_context: ^AstContext, ident: ^ast.Ident, context_pkg: string) -> (Symbol, bool) { + ident_pkg, pkg_ok := internal_resolve_type_identifier(ast_context, ident^) + if pkg_ok && ident_pkg.type == .Package { + return ident_pkg, true + } + return {}, false +} + +resolve_field_through_package_alias :: proc(ast_context: ^AstContext, field_symbol: Symbol, context_pkg: string) -> (Symbol, bool) { + ident, ok := get_package_ident_from_symbol(field_symbol) + if !ok { + return {}, false + } + + if is_path_package_name(ident.name) { + return Symbol{ + type = .Package, + pkg = ident.name, + value = SymbolPackageValue{}, + }, true + } + + current_package := ast_context.current_package + defer { + ast_context.current_package = current_package + } + ast_context.current_package = context_pkg + + pkg_symbol, pkg_ok := resolve_ident_as_package(ast_context, ident, context_pkg) + if pkg_ok && pkg_symbol.type == .Package { + return pkg_symbol, true + } + + return {}, false +} + +resolve_field_access_through_imported_alias :: proc(ast_context: ^AstContext, ident: ^ast.Ident, node: ^ast.Selector_Expr) -> (Symbol, bool) { + for imp in ast_context.imports { + if strings.compare(imp.base, ident.name) == 0 { + try_build_package(ast_context.current_package) + if node.field != nil { + symbol, ok := lookup(node.field.name, imp.name, node.pos.file) + if ok { + return resolve_symbol_return(ast_context, symbol) + } + } + } + } + + pkg_symbol, pkg_ok := internal_resolve_type_identifier(ast_context, ident^) + if pkg_ok { + if _, ok2 := pkg_symbol.value.(SymbolPackageValue); ok2 { + try_build_package(ast_context.current_package) + if node.field != nil { + symbol, ok := lookup(node.field.name, pkg_symbol.pkg, node.pos.file) + if ok { + return resolve_symbol_return(ast_context, symbol) + } + } + } + } + + return {}, false +} + // returns the symbol of the first return type of a proc resolve_symbol_proc_first_return_symbol :: proc(ast_context: ^AstContext, symbol: Symbol) -> (Symbol, bool) { if v, ok := symbol.value.(SymbolProcedureValue); ok { diff --git a/src/server/collector.odin b/src/server/collector.odin index 0f85774..a7785fa 100644 --- a/src/server/collector.odin +++ b/src/server/collector.odin @@ -1036,9 +1036,9 @@ get_package_mapping :: proc(file: ast.File, config: ^common.Config, directory: s package_map[name] = full } else { name: string - + pkg_name := imp.fullpath[1:len(imp.fullpath) - 1] full := path.join( - elems = {directory, imp.fullpath[1:len(imp.fullpath) - 1]}, + elems = {directory, pkg_name}, allocator = context.temp_allocator, ) full = path.clean(full, context.temp_allocator) @@ -1048,6 +1048,12 @@ get_package_mapping :: proc(file: ast.File, config: ^common.Config, directory: s } else { name = path.base(full, false, context.temp_allocator) } + // Check if the package already exists in the index and use that path + // This handles the case where packages are indexed separately (e.g., in tests) + test_path := path.join(elems = {"test", pkg_name}, allocator = context.temp_allocator) + if _, exists := indexer.index.collection.packages[test_path]; exists { + full = test_path + } package_map[name] = full } @@ -1098,6 +1104,12 @@ replace_package_alias_node :: proc(node: ^ast.Node, package_map: map[string]stri #partial switch n in node.derived { case ^Bad_Expr: case ^Ident: + // Replace stand-alone identifiers that are package aliases + if package_name, ok := package_map[n.name]; ok { + n.name = get_index_unique_string(collection, package_name) + } else if strings.contains(n.name, "/") { + n.name = get_index_unique_string(collection, n.name) + } case ^Implicit: case ^Undef: case ^Basic_Lit: diff --git a/src/server/generics.odin b/src/server/generics.odin index 347f484..71a0da7 100644 --- a/src/server/generics.odin +++ b/src/server/generics.odin @@ -39,6 +39,8 @@ resolve_poly :: proc( if ident, ok := unwrap_ident(type); ok { if untyped_value, ok := call_symbol.value.(SymbolUntypedValue); ok { save_poly_map(ident, symbol_to_expr(call_symbol, call_node.pos.file), poly_map) + } else if .Anonymous in call_symbol.flags { + save_poly_map(ident, call_node, poly_map) } else { save_poly_map( ident, @@ -515,6 +517,7 @@ resolve_generic_function_symbol :: proc( if file == "" { file = call_expr.args[i].pos.file } + symbol_expr := symbol_to_expr(symbol, file, context.temp_allocator) if symbol_expr == nil { @@ -525,17 +528,13 @@ resolve_generic_function_symbol :: proc( if symbol_value, ok := symbol.value.(SymbolProcedureValue); ok && len(symbol_value.return_types) > 0 { call_arg_count += get_proc_return_value_count(symbol_value.return_types) if _, ok := call_expr.args[i].derived.(^ast.Call_Expr); ok { - if symbol_value.return_types[0].type != nil { - if symbol, ok = resolve_type_expression(ast_context, symbol_value.return_types[0].type); - ok { - symbol_expr = symbol_to_expr( - symbol, - call_expr.args[i].pos.file, - context.temp_allocator, - ) - if symbol_expr == nil { - return {}, false - } + ret_type := symbol_value.return_types[0].type + if ret_type == nil { + ret_type = symbol_value.return_types[0].default_value + } + if ret_type != nil { + if symbol, ok = resolve_type_expression(ast_context, ret_type); ok { + symbol_expr = ret_type } } } diff --git a/src/server/references.odin b/src/server/references.odin index aeebd43..ddd5f17 100644 --- a/src/server/references.odin +++ b/src/server/references.odin @@ -14,29 +14,6 @@ import "core:strings" import "src:common" -fullpaths: [dynamic]string - -walk_directories :: proc(info: os.File_Info, in_err: os.Error, user_data: rawptr) -> (err: os.Error, skip_dir: bool) { - document := cast(^Document)user_data - - if info.type == .Directory { - return nil, false - } - - if info.fullpath == "" { - return nil, false - } - - if strings.contains(info.name, ".odin") { - slash_path, _ := filepath.replace_path_separators(info.fullpath, '/', context.temp_allocator) - if slash_path != document.fullpath { - append(&fullpaths, strings.clone(info.fullpath, context.temp_allocator)) - } - } - - return nil, false -} - prepare_references :: proc( document: ^Document, ast_context: ^AstContext, @@ -236,12 +213,13 @@ resolve_references :: proc( ast_context: ^AstContext, position_context: ^DocumentPositionContext, current_file_only := false, + include_declaration := true, ) -> ( []common.Location, bool, ) { locations := make([dynamic]common.Location, 0, ast_context.allocator) - fullpaths = make([dynamic]string, 0, ast_context.allocator) + fullpaths := make([dynamic]string, 0, ast_context.allocator) symbol, resolve_flag, ok := prepare_references(document, ast_context, position_context) @@ -253,9 +231,13 @@ resolve_references :: proc( for k, v in symbols_and_nodes { if strings.equal_fold(v.symbol.uri, symbol.uri) && v.symbol.range == symbol.range { node_uri := common.create_uri(v.node.pos.file, ast_context.allocator) - range := common.get_token_range(v.node^, ast_context.file.src) + if !include_declaration && v.symbol.range == range && strings.equal_fold(node_uri.uri, symbol.uri) { + // This is the declaration and so we skip it + continue + } + //We don't have to have the `.` with, otherwise it renames the dot. if _, ok := v.node.derived.(^ast.Implicit_Selector_Expr); ok { range.start.character += 1 @@ -309,9 +291,9 @@ resolve_references :: proc( context.allocator = runtime.arena_allocator(&arena) - fullpaths := slice.unique(fullpaths[:]) + paths := slice.unique(fullpaths[:]) - for fullpath in fullpaths { + for fullpath in paths { dir := filepath.dir(fullpath) base := filepath.base(dir) forward_dir, _ := filepath.replace_path_separators(dir, '/', context.allocator) @@ -383,6 +365,13 @@ resolve_references :: proc( if strings.equal_fold(v.symbol.uri, symbol.uri) && v.symbol.range == symbol.range { node_uri := common.create_uri(v.node.pos.file, ast_context.allocator) range := common.get_token_range(v.node^, string(document.text)) + + if !include_declaration && + v.symbol.range == range && + strings.equal_fold(node_uri.uri, symbol.uri) { + // This is the declaration and so we skip it + continue + } //We don't have to have the `.` with, otherwise it renames the dot. if _, ok := v.node.derived.(^ast.Implicit_Selector_Expr); ok { range.start.character += 1 @@ -406,6 +395,7 @@ get_references :: proc( document: ^Document, position: common.Position, current_file_only := false, + include_declaration := true, ) -> ( []common.Location, bool, @@ -435,7 +425,13 @@ get_references :: proc( get_locals(document.ast, position_context.function, &ast_context, &position_context) } - locations, ok2 := resolve_references(document, &ast_context, &position_context, current_file_only) + locations, ok2 := resolve_references( + document, + &ast_context, + &position_context, + current_file_only, + include_declaration = include_declaration, + ) temp_locations := make([dynamic]common.Location, 0, context.temp_allocator) diff --git a/src/server/requests.odin b/src/server/requests.odin index 3aec455..57d190d 100644 --- a/src/server/requests.odin +++ b/src/server/requests.odin @@ -214,12 +214,7 @@ read_and_parse_body :: proc(reader: ^Reader, header: Header) -> (json.Value, boo return value, true } -call_map: map[string]proc( - _: json.Value, - _: RequestId, - _: ^common.Config, - _: ^Writer, -) -> common.Error = { +call_map: map[string]proc(_: json.Value, _: RequestId, _: ^common.Config, _: ^Writer) -> common.Error = { "initialize" = request_initialize, "initialized" = request_initialized, "shutdown" = request_shutdown, @@ -331,10 +326,7 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common. if !ok { log.errorf("Failed to find method: %#v", root) - response := make_response_message_error( - id = id, - error = ResponseError{code = .MethodNotFound, message = ""}, - ) + response := make_response_message_error(id = id, error = ResponseError{code = .MethodNotFound, message = ""}) send_error(response, writer) return } @@ -352,10 +344,7 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common. } else { err := fn(root["params"], id, config, writer) if err != .None { - response := make_response_message_error( - id = id, - error = ResponseError{code = err, message = ""}, - ) + response := make_response_message_error(id = id, error = ResponseError{code = err, message = ""}) send_error(response, writer) } } @@ -364,20 +353,13 @@ call :: proc(value: json.Value, id: RequestId, writer: ^Writer, config: ^common. //log.errorf("time duration %v for %v", time.duration_milliseconds(diff), method) } -read_ols_initialize_options :: proc( - config: ^common.Config, - ols_config: OlsConfig, - uri: common.Uri, -) { - config.disable_parser_errors = - ols_config.disable_parser_errors.(bool) or_else config.disable_parser_errors +read_ols_initialize_options :: proc(config: ^common.Config, ols_config: OlsConfig, uri: common.Uri) { + config.disable_parser_errors = ols_config.disable_parser_errors.(bool) or_else config.disable_parser_errors config.thread_count = ols_config.thread_pool_count.(int) or_else config.thread_count - config.enable_document_symbols = - ols_config.enable_document_symbols.(bool) or_else config.enable_document_symbols + config.enable_document_symbols = ols_config.enable_document_symbols.(bool) or_else config.enable_document_symbols config.enable_format = ols_config.enable_format.(bool) or_else config.enable_format config.enable_hover = ols_config.enable_hover.(bool) or_else config.enable_hover - config.enable_semantic_tokens = - ols_config.enable_semantic_tokens.(bool) or_else config.enable_semantic_tokens + config.enable_semantic_tokens = ols_config.enable_semantic_tokens.(bool) or_else config.enable_semantic_tokens config.enable_unused_imports_reporting = ols_config.enable_unused_imports_reporting.(bool) or_else config.enable_unused_imports_reporting config.enable_procedure_context = @@ -388,20 +370,20 @@ read_ols_initialize_options :: proc( ols_config.enable_document_highlights.(bool) or_else config.enable_document_highlights config.enable_completion_matching = ols_config.enable_completion_matching.(bool) or_else config.enable_completion_matching - config.enable_document_links = - ols_config.enable_document_links.(bool) or_else config.enable_document_links + config.enable_document_links = ols_config.enable_document_links.(bool) or_else config.enable_document_links config.enable_comp_lit_signature_help = ols_config.enable_comp_lit_signature_help.(bool) or_else config.enable_comp_lit_signature_help config.enable_comp_lit_signature_help_use_docs = ols_config.enable_comp_lit_signature_help_use_docs.(bool) or_else config.enable_comp_lit_signature_help_use_docs + config.enable_code_action_invert_if = + ols_config.enable_code_action_invert_if.(bool) or_else config.enable_code_action_invert_if config.verbose = ols_config.verbose.(bool) or_else config.verbose config.file_log = ols_config.file_log.(bool) or_else config.file_log config.enable_procedure_snippet = ols_config.enable_procedure_snippet.(bool) or_else config.enable_procedure_snippet - config.enable_auto_import = - ols_config.enable_auto_import.(bool) or_else config.enable_auto_import + config.enable_auto_import = ols_config.enable_auto_import.(bool) or_else config.enable_auto_import config.enable_checker_only_saved = ols_config.enable_checker_only_saved.(bool) or_else config.enable_checker_only_saved @@ -417,10 +399,7 @@ read_ols_initialize_options :: proc( } if ols_config.odin_root_override != "" { - config.odin_root_override = strings.clone( - ols_config.odin_root_override, - context.temp_allocator, - ) + config.odin_root_override = strings.clone(ols_config.odin_root_override, context.temp_allocator) allocated: bool config.odin_root_override, allocated = common.resolve_home_dir(config.odin_root_override) @@ -473,8 +452,7 @@ read_ols_initialize_options :: proc( config.enable_inlay_hints_implicit_return = ols_config.enable_inlay_hints_implicit_return.(bool) or_else config.enable_inlay_hints_implicit_return - config.enable_fake_method = - ols_config.enable_fake_methods.(bool) or_else config.enable_fake_method + config.enable_fake_method = ols_config.enable_fake_methods.(bool) or_else config.enable_fake_method config.enable_overload_resolution = ols_config.enable_overload_resolution.(bool) or_else config.enable_overload_resolution @@ -512,10 +490,7 @@ read_ols_initialize_options :: proc( } else { final_path, _ = filepath.replace_path_separators( common.get_case_sensitive_path( - path.join( - elems = {uri.path, forward_path}, - allocator = context.temp_allocator, - ), + path.join(elems = {uri.path, forward_path}, allocator = context.temp_allocator), context.temp_allocator, ), '/', @@ -537,11 +512,7 @@ read_ols_initialize_options :: proc( log.errorf("Failed to find absolute address of collection: %v", final_path, err) config.collections[strings.clone(it.name)] = strings.clone(final_path) } else { - slashed_path, _ := filepath.replace_path_separators( - abs_final_path, - '/', - context.temp_allocator, - ) + slashed_path, _ := filepath.replace_path_separators(abs_final_path, '/', context.temp_allocator) config.collections[strings.clone(it.name)] = strings.clone(slashed_path) } @@ -587,8 +558,7 @@ read_ols_initialize_options :: proc( } if odin_core_env != "" { - if abs_core_env, err := filepath.abs(odin_core_env, context.temp_allocator); - err == nil { + if abs_core_env, err := filepath.abs(odin_core_env, context.temp_allocator); err == nil { odin_core_env = abs_core_env } } @@ -599,11 +569,7 @@ read_ols_initialize_options :: proc( // Insert the default collections if they are not specified in the config. if odin_core_env != "" { - forward_path, _ := filepath.replace_path_separators( - odin_core_env, - '/', - context.temp_allocator, - ) + forward_path, _ := filepath.replace_path_separators(odin_core_env, '/', context.temp_allocator) // base if "base" not_in config.collections { @@ -631,10 +597,7 @@ read_ols_initialize_options :: proc( // shared if "shared" not_in config.collections { - shared_path := path.join( - elems = {forward_path, "shared"}, - allocator = context.allocator, - ) + shared_path := path.join(elems = {forward_path, "shared"}, allocator = context.allocator) if os.exists(shared_path) { config.collections[strings.clone("shared")] = shared_path } else { @@ -739,10 +702,7 @@ request_initialize :: proc( read_ols_initialize_options(config, initialize_params.initializationOptions, uri) // Apply ols.json config. - ols_config_path := path.join( - elems = {uri.path, "ols.json"}, - allocator = context.temp_allocator, - ) + ols_config_path := path.join(elems = {uri.path, "ols.json"}, allocator = context.temp_allocator) read_ols_config(ols_config_path, config, uri) } else { read_ols_initialize_options(config, initialize_params.initializationOptions, {}) @@ -763,8 +723,7 @@ request_initialize :: proc( config.enable_label_details = initialize_params.capabilities.textDocument.completion.completionItem.labelDetailsSupport - config.enable_snippets &= - initialize_params.capabilities.textDocument.completion.completionItem.snippetSupport + config.enable_snippets &= initialize_params.capabilities.textDocument.completion.completionItem.snippetSupport config.signature_offset_support = initialize_params.capabilities.textDocument.signatureHelp.signatureInformation.parameterInformation.labelOffsetSupport @@ -773,17 +732,12 @@ request_initialize :: proc( signatureTriggerCharacters := []string{"(", ","} signatureRetriggerCharacters := []string{","} - semantic_range_support := - initialize_params.capabilities.textDocument.semanticTokens.requests.range + semantic_range_support := initialize_params.capabilities.textDocument.semanticTokens.requests.range response := make_response_message( params = ResponseInitializeParams { capabilities = ServerCapabilities { - textDocumentSync = TextDocumentSyncOptions { - openClose = true, - change = 2, - save = {includeText = true}, - }, + textDocumentSync = TextDocumentSyncOptions{openClose = true, change = 2, save = {includeText = true}}, renameProvider = RenameOptions{prepareProvider = true}, workspaceSymbolProvider = true, referencesProvider = config.enable_references, @@ -814,10 +768,7 @@ request_initialize :: proc( hoverProvider = config.enable_hover, documentFormattingProvider = config.enable_format, documentLinkProvider = {resolveProvider = false}, - codeActionProvider = { - resolveProvider = false, - codeActionKinds = {"refactor.rewrite"}, - }, + codeActionProvider = {resolveProvider = false, codeActionKinds = {"refactor.rewrite"}}, }, }, id = id, @@ -883,12 +834,7 @@ request_initialized :: proc( return .None } -request_shutdown :: proc( - params: json.Value, - id: RequestId, - config: ^common.Config, - writer: ^Writer, -) -> common.Error { +request_shutdown :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error { response := make_response_message(params = nil, id = id) send_response(response, writer) @@ -1003,12 +949,7 @@ request_completion :: proc( } list: CompletionList - list, ok = get_completion_list( - document, - completition_params.position, - completition_params.context_, - config, - ) + list, ok = get_completion_list(document, completition_params.position, completition_params.context_, config) if !ok { return .InternalError @@ -1102,12 +1043,7 @@ request_format_document :: proc( return .None } -notification_exit :: proc( - params: json.Value, - id: RequestId, - config: ^common.Config, - writer: ^Writer, -) -> common.Error { +notification_exit :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error { config.running = false return .None } @@ -1134,12 +1070,7 @@ notification_did_open :: proc( defer delete(open_params.textDocument.uri) - if n := document_open( - open_params.textDocument.uri, - open_params.textDocument.text, - config, - writer, - ); n != .None { + if n := document_open(open_params.textDocument.uri, open_params.textDocument.text, config, writer); n != .None { return .InternalError } @@ -1375,12 +1306,7 @@ request_document_symbols :: proc( return .None } -request_hover :: proc( - params: json.Value, - id: RequestId, - config: ^common.Config, - writer: ^Writer, -) -> common.Error { +request_hover :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error { params_object, ok := params.(json.Object) if !ok { @@ -1528,12 +1454,7 @@ request_prepare_rename :: proc( return .None } -request_rename :: proc( - params: json.Value, - id: RequestId, - config: ^common.Config, - writer: ^Writer, -) -> common.Error { +request_rename :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error { params_object, ok := params.(json.Object) if !ok { @@ -1580,7 +1501,13 @@ request_references :: proc( reference_param: ReferenceParams - if unmarshal(params, reference_param, context.temp_allocator) != nil { + // Due to the field named `context`, we need to use json tags and this is the easiest way to handle that right now. + data, err := json.marshal(params_object) + if err != nil { + return .ParseError + } + + if err := json.unmarshal(data, &reference_param, allocator = context.temp_allocator); err != nil { return .ParseError } @@ -1591,7 +1518,11 @@ request_references :: proc( } locations: []common.Location - locations, ok = get_references(document, reference_param.position) + locations, ok = get_references( + document, + reference_param.position, + include_declaration = reference_param.ctx.includeDeclaration, + ) if !ok { return .InternalError @@ -1783,11 +1714,6 @@ request_workspace_symbols :: proc( return .None } -request_noop :: proc( - params: json.Value, - id: RequestId, - config: ^common.Config, - writer: ^Writer, -) -> common.Error { +request_noop :: proc(params: json.Value, id: RequestId, config: ^common.Config, writer: ^Writer) -> common.Error { return .None } diff --git a/src/server/types.odin b/src/server/types.odin index c65e181..fc52119 100644 --- a/src/server/types.odin +++ b/src/server/types.odin @@ -436,6 +436,7 @@ OlsConfig :: struct { enable_procedure_snippet: Maybe(bool), enable_checker_only_saved: Maybe(bool), enable_auto_import: Maybe(bool), + enable_code_action_invert_if: Maybe(bool), disable_parser_errors: Maybe(bool), verbose: Maybe(bool), file_log: Maybe(bool), @@ -561,9 +562,14 @@ PrepareRenameParams :: struct { position: common.Position, } +ReferenceContext :: struct { + includeDeclaration: bool, +} + ReferenceParams :: struct { textDocument: TextDocumentIdentifier, position: common.Position, + ctx: ReferenceContext `json:"context"`, } HighlightParams :: struct { diff --git a/src/testing/testing.odin b/src/testing/testing.odin index 5e927a7..46c597f 100644 --- a/src/testing/testing.odin +++ b/src/testing/testing.odin @@ -464,11 +464,12 @@ expect_reference_locations :: proc( src: ^Source, expect_locations: []common.Location, expect_excluded: []common.Location = nil, + include_declaration := true, ) { setup(src) defer teardown(src) - locations, ok := server.get_references(src.document, src.position) + locations, ok := server.get_references(src.document, src.position, include_declaration = include_declaration) for expect_location in expect_locations { match := false |