aboutsummaryrefslogtreecommitdiff
path: root/src/server
diff options
context:
space:
mode:
authorBrad Lewis <22850972+BradLewis@users.noreply.github.com>2025-06-09 14:32:59 -0400
committerBrad Lewis <22850972+BradLewis@users.noreply.github.com>2025-06-13 15:23:36 -0400
commit8816d531da666959b7df0c64401b8fa064b0cd3d (patch)
tree893cf9b15bb610f13b2c6cf3c2c3148166341ed4 /src/server
parent023c8a0fd1d8d8e54b38bf990a74816a31dee68e (diff)
Add textDocument/typeDefinition support
Diffstat (limited to 'src/server')
-rw-r--r--src/server/analysis.odin238
-rw-r--r--src/server/requests.odin42
-rw-r--r--src/server/type_definition.odin246
-rw-r--r--src/server/types.odin1
4 files changed, 452 insertions, 75 deletions
diff --git a/src/server/analysis.odin b/src/server/analysis.odin
index 1a7e4a7..9e43467 100644
--- a/src/server/analysis.odin
+++ b/src/server/analysis.odin
@@ -796,6 +796,39 @@ check_node_recursion :: proc(ast_context: ^AstContext, node: ^ast.Node) -> bool
return false
}
+// Resolves the location of the underlying type of the expression
+resolve_location_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Expr) -> (Symbol, bool) {
+ if node == nil {
+ return {}, false
+ }
+
+ //Try to prevent stack overflows and prevent indexing out of bounds.
+ if ast_context.deferred_count >= DeferredDepth {
+ return {}, false
+ }
+
+ set_ast_package_scoped(ast_context)
+
+ if check_node_recursion(ast_context, node) {
+ return {}, false
+ }
+
+ // TODO: there is likely more of these that will be needed as well.
+ // I think we'll need a better way to manage this all
+ #partial switch n in node.derived {
+ case ^ast.Ident:
+ if _, ok := common.keyword_map[n.name]; ok {
+ return {}, true
+ }
+ return resolve_location_type_identifier(ast_context, n^)
+ case ^ast.Basic_Lit:
+ return {}, true
+ case ^ast.Array_Type:
+ return resolve_location_type_expression(ast_context, n.elem)
+ }
+ return resolve_type_expression(ast_context, node)
+}
+
resolve_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Expr) -> (Symbol, bool) {
clear(&ast_context.recursion_map)
return internal_resolve_type_expression(ast_context, node)
@@ -929,7 +962,7 @@ internal_resolve_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Ex
}
case ^Proc_Lit:
if v.type.results != nil {
- if len(v.type.results.list) == 1 {
+ if len(v.type.results.list) > 0 {
return internal_resolve_type_expression(ast_context, v.type.results.list[0].type)
}
}
@@ -1004,96 +1037,123 @@ internal_resolve_type_expression :: proc(ast_context: ^AstContext, node: ^ast.Ex
return selector, true
}
case ^Selector_Expr:
- if selector, ok := internal_resolve_type_expression(ast_context, v.expr); ok {
- ast_context.use_locals = false
+ return resolve_selector_expression(ast_context, v)
+ case ^ast.Poly_Type:
+ if v.specialization != nil {
+ return internal_resolve_type_expression(ast_context, v.specialization)
+ }
- set_ast_package_from_symbol_scoped(ast_context, selector)
+ case:
+ log.warnf("default node kind, internal_resolve_type_expression: %v", v)
+ }
- #partial switch s in selector.value {
- case SymbolFixedArrayValue:
- components_count := 0
- for c in v.field.name {
- if c == 'x' || c == 'y' || c == 'z' || c == 'w' || c == 'r' || c == 'g' || c == 'b' || c == 'a' {
- components_count += 1
- } else {
- return {}, false
- }
- }
+ return Symbol{}, false
+}
- if components_count == 0 {
- return {}, false
- }
+resolve_selector_expression :: proc(ast_context: ^AstContext, node: ^ast.Selector_Expr) -> (Symbol, bool) {
+ if selector, ok := internal_resolve_type_expression(ast_context, node.expr); ok {
+ ast_context.use_locals = false
- if components_count == 1 {
- set_ast_package_from_symbol_scoped(ast_context, selector)
+ set_ast_package_from_symbol_scoped(ast_context, selector)
- symbol, ok := internal_resolve_type_expression(ast_context, s.expr)
- symbol.type = .Variable
- return symbol, ok
+ #partial switch s in selector.value {
+ case SymbolFixedArrayValue:
+ components_count := 0
+ for c in node.field.name {
+ if c == 'x' || c == 'y' || c == 'z' || c == 'w' || c == 'r' || c == 'g' || c == 'b' || c == 'a' {
+ components_count += 1
} else {
- value := SymbolFixedArrayValue {
- expr = s.expr,
- len = make_int_basic_value(ast_context, components_count, s.len.pos, s.len.end),
- }
- selector.value = value
- selector.type = .Variable
- return selector, true
+ return {}, false
}
- case SymbolProcedureValue:
- if len(s.return_types) == 1 {
- selector_expr := new_type(
- ast.Selector_Expr,
- s.return_types[0].node.pos,
- s.return_types[0].node.end,
- ast_context.allocator,
- )
- selector_expr.expr = s.return_types[0].type
- selector_expr.field = v.field
+ }
+
+ if components_count == 0 {
+ return {}, false
+ }
+
+ if components_count == 1 {
+ set_ast_package_from_symbol_scoped(ast_context, selector)
- return internal_resolve_type_expression(ast_context, selector_expr)
+ symbol, ok := internal_resolve_type_expression(ast_context, s.expr)
+ symbol.type = .Variable
+ return symbol, ok
+ } else {
+ value := SymbolFixedArrayValue {
+ expr = s.expr,
+ len = make_int_basic_value(ast_context, components_count, s.len.pos, s.len.end),
}
- case SymbolStructValue:
- for name, i in s.names {
- if v.field != nil && name == v.field.name {
- ast_context.field_name = v.field^
- symbol, ok := internal_resolve_type_expression(ast_context, s.types[i])
- symbol.type = .Variable
- return symbol, ok
- }
+ selector.value = value
+ selector.type = .Variable
+ return selector, true
+ }
+ case SymbolProcedureValue:
+ if len(s.return_types) == 1 {
+ selector_expr := new_type(
+ ast.Selector_Expr,
+ s.return_types[0].node.pos,
+ s.return_types[0].node.end,
+ ast_context.allocator,
+ )
+ selector_expr.expr = s.return_types[0].type
+ selector_expr.field = node.field
+
+ return internal_resolve_type_expression(ast_context, selector_expr)
+ }
+ case SymbolStructValue:
+ for name, i in s.names {
+ if node.field != nil && name == node.field.name {
+ ast_context.field_name = node.field^
+ symbol, ok := internal_resolve_type_expression(ast_context, s.types[i])
+ symbol.type = .Variable
+ return symbol, ok
}
- case SymbolBitFieldValue:
- for name, i in s.names {
- if v.field != nil && name == v.field.name {
- ast_context.field_name = v.field^
- symbol, ok := internal_resolve_type_expression(ast_context, s.types[i])
- symbol.type = .Variable
- return symbol, ok
- }
+ }
+ case SymbolBitFieldValue:
+ for name, i in s.names {
+ if node.field != nil && name == node.field.name {
+ ast_context.field_name = node.field^
+ symbol, ok := internal_resolve_type_expression(ast_context, s.types[i])
+ symbol.type = .Variable
+ return symbol, ok
}
- case SymbolPackageValue:
- try_build_package(ast_context.current_package)
+ }
+ case SymbolPackageValue:
+ try_build_package(ast_context.current_package)
- if v.field != nil {
- return resolve_symbol_return(ast_context, lookup(v.field.name, selector.pkg))
- } else {
- return Symbol{}, false
- }
- case SymbolEnumValue:
- // enum members probably require own symbol value
- selector.type = .EnumMember
- return selector, true
+ if node.field != nil {
+ return resolve_symbol_return(ast_context, lookup(node.field.name, selector.pkg))
+ } else {
+ return Symbol{}, false
}
+ case SymbolEnumValue:
+ // enum members probably require own symbol value
+ selector.type = .EnumMember
+ return selector, true
}
- case ^ast.Poly_Type:
- if v.specialization != nil {
- return internal_resolve_type_expression(ast_context, v.specialization)
- }
-
- case:
- log.warnf("default node kind, internal_resolve_type_expression: %v", v)
}
- return Symbol{}, false
+ return {}, false
+}
+
+// returns the symbol of the first return type of a proc
+resolve_symbol_proc_first_return_symbol :: proc(ast_context: ^AstContext, symbol: Symbol) -> (Symbol, bool) {
+ if v, ok := symbol.value.(SymbolProcedureValue); ok {
+ if len(v.return_types) > 0 {
+ if ast_context.current_package != symbol.pkg {
+ current_package := ast_context.current_package
+ defer {
+ ast_context.current_package = current_package
+ }
+ ast_context.current_package = symbol.pkg
+ return resolve_type_expression(ast_context, v.return_types[0].type)
+ } else {
+ return resolve_location_type_expression(ast_context, v.return_types[0].type)
+ }
+ } else {
+ return {}, true
+ }
+ }
+ return {}, false
}
store_local :: proc(
@@ -1962,6 +2022,33 @@ resolve_unresolved_symbol :: proc(ast_context: ^AstContext, symbol: ^Symbol) ->
return true
}
+// Resolves the location of the underlying type of the identifier
+resolve_location_type_identifier :: proc(ast_context: ^AstContext, node: ast.Ident) -> (Symbol, bool) {
+ // TODO: will likely need to clean this up and find a way for this flow to make sense.
+ // Ideally we need a way to extract the full symbol of a global
+ if local, ok := get_local(ast_context^, node); ok {
+ #partial switch n in local.rhs.derived {
+ case ^ast.Ident:
+ return resolve_location_identifier(ast_context, n^)
+ case ^ast.Basic_Lit:
+ return {}, true
+ case ^ast.Array_Type:
+ if elem, ok := n.elem.derived.(^ast.Ident); ok {
+ return resolve_location_identifier(ast_context, elem^)
+ }
+ case ^ast.Selector_Expr:
+ return resolve_selector_expression(ast_context, n)
+ }
+ } else if global, ok := ast_context.globals[node.name]; ok {
+ if v, ok := global.expr.derived.(^ast.Proc_Lit); ok {
+ if symbol, ok := resolve_type_expression(ast_context, global.name_expr); ok {
+ return symbol, ok
+ }
+ }
+ }
+ return resolve_location_identifier(ast_context, node)
+}
+
resolve_location_identifier :: proc(ast_context: ^AstContext, node: ast.Ident) -> (Symbol, bool) {
symbol: Symbol
@@ -1994,6 +2081,7 @@ resolve_location_identifier :: proc(ast_context: ^AstContext, node: ast.Ident) -
return {}, false
}
+
resolve_location_comp_lit_field :: proc(
ast_context: ^AstContext,
position_context: ^DocumentPositionContext,
diff --git a/src/server/requests.odin b/src/server/requests.odin
index d5a3481..837bc9b 100644
--- a/src/server/requests.odin
+++ b/src/server/requests.odin
@@ -228,6 +228,7 @@ call_map: map[string]proc(_: json.Value, _: RequestId, _: ^common.Config, _: ^Wr
"textDocument/didClose" = notification_did_close,
"textDocument/didSave" = notification_did_save,
"textDocument/definition" = request_definition,
+ "textDocument/typeDefinition" = request_type_definition,
"textDocument/completion" = request_completion,
"textDocument/signatureHelp" = request_signature_help,
"textDocument/documentSymbol" = request_document_symbols,
@@ -680,6 +681,7 @@ request_initialize :: proc(
workspaceSymbolProvider = true,
referencesProvider = config.enable_references,
definitionProvider = true,
+ typeDefinitionProvider = true,
completionProvider = CompletionOptions {
resolveProvider = false,
triggerCharacters = completionTriggerCharacters,
@@ -813,6 +815,46 @@ request_definition :: proc(
return .None
}
+request_type_definition :: proc(
+ params: json.Value,
+ id: RequestId,
+ config: ^common.Config,
+ writer: ^Writer,
+) -> common.Error {
+ params_object, ok := params.(json.Object)
+
+ if !ok {
+ return .ParseError
+ }
+
+ definition_params: TextDocumentPositionParams
+
+ if unmarshal(params, definition_params, context.temp_allocator) != nil {
+ return .ParseError
+ }
+
+ document := document_get(definition_params.textDocument.uri)
+
+ if document == nil {
+ return .InternalError
+ }
+
+ locations, ok2 := get_type_definition_locations(document, definition_params.position)
+ if !ok2 {
+ log.warn("Failed to get type definition location")
+ }
+
+ if len(locations) == 1 {
+ response := make_response_message(params = locations[0], id = id)
+ send_response(response, writer)
+ } else {
+ response := make_response_message(params = locations, id = id)
+ send_response(response, writer)
+ }
+
+ return .None
+}
+
request_completion :: proc(
params: json.Value,
id: RequestId,
diff --git a/src/server/type_definition.odin b/src/server/type_definition.odin
new file mode 100644
index 0000000..2b5ec81
--- /dev/null
+++ b/src/server/type_definition.odin
@@ -0,0 +1,246 @@
+package server
+
+import "core:fmt"
+import "core:log"
+import "core:mem"
+import "core:strings"
+import "core:odin/ast"
+
+import "src:common"
+
+@(private = "file")
+append_symbol_to_locations :: proc(locations: ^[dynamic]common.Location, document: ^Document, symbol: Symbol) {
+ if symbol.range == {} {
+ return
+ }
+ location := common.Location{}
+ location.range = symbol.range
+ if symbol.uri == "" {
+ location.uri = document.uri.uri
+ } else {
+ location.uri = symbol.uri
+ }
+ append(locations, location)
+}
+
+get_type_definition_locations :: proc(document: ^Document, position: common.Position) -> ([]common.Location, bool) {
+ uri: string
+ locations := make([dynamic]common.Location, context.temp_allocator)
+
+ position_context, ok := get_document_position_context(document, position, .Definition)
+
+ if !ok {
+ log.warn("Failed to get position context")
+ return {}, false
+ }
+
+ ast_context := make_ast_context(
+ document.ast,
+ document.imports,
+ document.package_name,
+ document.uri.uri,
+ document.fullpath,
+ )
+
+ ast_context.position_hint = position_context.hint
+
+ get_globals(document.ast, &ast_context)
+
+ if position_context.function != nil {
+ get_locals(document.ast, position_context.function, &ast_context, &position_context)
+ }
+
+ if position_context.import_stmt != nil {
+ return {}, false
+ }
+
+ if position_context.identifier != nil {
+ if ident, ok := position_context.identifier.derived.(^ast.Ident); ok {
+ if _, ok := common.keyword_map[ident.name]; ok {
+ return {}, false
+ }
+
+ if str, ok := builtin_identifier_hover[ident.name]; ok {
+ return {}, false
+ }
+ }
+ }
+
+ if position_context.call != nil {
+ if call, ok := position_context.call.derived.(^ast.Call_Expr); ok {
+ if !position_in_exprs(call.args, position_context.position) {
+ if call_symbol, ok := resolve_type_expression(&ast_context, position_context.call); ok {
+ if symbol, ok := resolve_symbol_proc_first_return_symbol(&ast_context, call_symbol); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ return {}, false
+ }
+ }
+ }
+ }
+
+ if position_context.struct_type != nil {
+ for field in position_context.struct_type.fields.list {
+ for name in field.names {
+ if position_in_node(name, position_context.position) {
+ if identifier, ok := name.derived.(^ast.Ident); ok && field.type != nil {
+ if position_context.value_decl != nil && len(position_context.value_decl.names) != 0 {
+ if symbol, ok := resolve_location_type_expression(&ast_context, field.type); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if position_context.field_value != nil && position_context.comp_lit != nil {
+ if comp_symbol, ok := resolve_comp_literal(&ast_context, &position_context); ok {
+ if field, ok := position_context.field_value.field.derived.(^ast.Ident); ok {
+ if position_in_node(field, position_context.position) {
+ if v, ok := comp_symbol.value.(SymbolStructValue); ok {
+ for name, i in v.names {
+ if name == field.name {
+ if symbol, ok := resolve_location_type_expression(&ast_context, v.types[i]); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ }
+ } else if v, ok := comp_symbol.value.(SymbolBitFieldValue); ok {
+ for name, i in v.names {
+ if name == field.name {
+ if symbol, ok := resolve_type_expression(&ast_context, v.types[i]); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if position_context.selector != nil &&
+ position_context.identifier != nil &&
+ position_context.field == position_context.identifier {
+ reset_ast_context(&ast_context)
+
+ ast_context.current_package = ast_context.document_package
+
+ //if the base selector is the client wants to go to.
+ if base, ok := position_context.selector.derived.(^ast.Ident); ok && position_context.identifier != nil {
+ ident := position_context.identifier.derived.(^ast.Ident)^
+
+ if position_in_node(base, position_context.position) {
+ if symbol, ok := resolve_location_type_identifier(&ast_context, ident); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+
+ selector: Symbol
+
+ selector, ok = resolve_type_expression(&ast_context, position_context.selector)
+
+ if !ok {
+ return {}, false
+ }
+
+ field: string
+
+ if position_context.field != nil {
+ #partial switch v in position_context.field.derived {
+ case ^ast.Ident:
+ field = v.name
+ }
+ }
+
+ if v, is_proc := selector.value.(SymbolProcedureValue); is_proc {
+ if len(v.return_types) == 0 || v.return_types[0].type == nil {
+ return {}, false
+ }
+
+ set_ast_package_set_scoped(&ast_context, selector.pkg)
+
+ if selector, ok = resolve_location_type_expression(&ast_context, v.return_types[0].type); !ok {
+ return {}, false
+ }
+ }
+
+ ast_context.current_package = selector.pkg
+
+ #partial switch v in selector.value {
+ case SymbolStructValue:
+ for name, i in v.names {
+ if name == field {
+ if symbol, ok := resolve_location_type_expression(&ast_context, v.types[i]); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ case SymbolBitFieldValue:
+ for name, i in v.names {
+ if name == field {
+ if symbol, ok := resolve_type_expression(&ast_context, v.types[i]); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ case SymbolPackageValue:
+ if position_context.field != nil {
+ if ident, ok := position_context.field.derived.(^ast.Ident); ok {
+ // check to see if we are in a position call context
+ if position_context.call != nil && ast_context.call == nil {
+ if call, ok := position_context.call.derived.(^ast.Call_Expr); ok {
+ if !position_in_exprs(call.args, position_context.position) {
+ ast_context.call = call
+ }
+ }
+ }
+ if symbol, ok := resolve_type_identifier(&ast_context, ident^); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+ }
+ }
+ } else if position_context.identifier != nil {
+ reset_ast_context(&ast_context)
+
+ ast_context.current_package = ast_context.document_package
+
+ ident := position_context.identifier.derived.(^ast.Ident)^
+
+ if position_context.value_decl != nil {
+ ident.pos = position_context.value_decl.end
+ ident.end = position_context.value_decl.end
+ }
+
+ if position_context.call != nil {
+ if call, ok := position_context.call.derived.(^ast.Call_Expr); ok {
+ if !position_in_exprs(call.args, position_context.position) {
+ ast_context.call = call
+ }
+ }
+ }
+
+ if symbol, ok := resolve_location_type_identifier(&ast_context, ident); ok {
+ if symbol, ok := resolve_symbol_proc_first_return_symbol(&ast_context, symbol); ok {
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ append_symbol_to_locations(&locations, document, symbol)
+ return locations[:], true
+ }
+ }
+
+ return {}, false
+}
diff --git a/src/server/types.odin b/src/server/types.odin
index 4a4ef8f..24dc2e8 100644
--- a/src/server/types.odin
+++ b/src/server/types.odin
@@ -132,6 +132,7 @@ MarkupContent :: struct {
ServerCapabilities :: struct {
textDocumentSync: TextDocumentSyncOptions,
definitionProvider: bool,
+ typeDefinitionProvider: bool,
completionProvider: CompletionOptions,
signatureHelpProvider: SignatureHelpOptions,
semanticTokensProvider: SemanticTokensOptions,