diff options
| -rw-r--r-- | src/common/ast.odin | 29 | ||||
| -rw-r--r-- | src/server/analysis.odin | 72 | ||||
| -rw-r--r-- | src/server/generics.odin | 67 | ||||
| -rw-r--r-- | tests/completions_test.odin | 52 |
4 files changed, 158 insertions, 62 deletions
diff --git a/src/common/ast.odin b/src/common/ast.odin index 29d3e6c..b215eb8 100644 --- a/src/common/ast.odin +++ b/src/common/ast.odin @@ -207,6 +207,35 @@ unwrap_pointer_expr :: proc(expr: ^ast.Expr) -> (^ast.Expr, int, bool) { return expr, n, true } +expr_contains_poly :: proc(expr: ^ast.Expr) -> bool { + if expr == nil { + return false + } + + visit :: proc(visitor: ^ast.Visitor, node: ^ast.Node) -> ^ast.Visitor { + if node == nil { + return nil + } + if _, ok := node.derived.(^ast.Poly_Type); ok { + b := cast(^bool)visitor.data + b^ = true + return nil + } + return visitor + } + + found := false + + visitor := ast.Visitor { + visit = visit, + data = &found, + } + + ast.walk(&visitor, expr) + + return found +} + is_expr_basic_lit :: proc(expr: ^ast.Expr) -> bool { _, ok := expr.derived.(^ast.Basic_Lit) return ok diff --git a/src/server/analysis.odin b/src/server/analysis.odin index d6592ae..f111c5d 100644 --- a/src/server/analysis.odin +++ b/src/server/analysis.odin @@ -101,6 +101,7 @@ AstContext :: struct { uri: string, fullpath: string, non_mutable_only: bool, + overloading: bool, } make_ast_context :: proc( @@ -617,6 +618,13 @@ resolve_function_overload :: proc( Symbol, bool, ) { + old_overloading := ast_context.overloading + ast_context.overloading = true + + defer { + ast_context.overloading = old_overloading + } + using ast call_expr := ast_context.call @@ -1445,22 +1453,10 @@ internal_resolve_type_identifier :: proc( make_symbol_bit_field_from_ast(ast_context, v^, node), true return_symbol.name = node.name case ^Proc_Lit: - if !is_procedure_generic(v.type) { - return_symbol, ok = - make_symbol_procedure_from_ast( - ast_context, - local.rhs, - v.type^, - node, - {}, - false, - ), - true - } else { - if return_symbol, ok = resolve_generic_function( - ast_context, - v^, - ); !ok { + if is_procedure_generic(v.type) { + return_symbol, ok = resolve_generic_function(ast_context, v^) + + if !ok && !ast_context.overloading { return_symbol, ok = make_symbol_procedure_from_ast( ast_context, @@ -1472,6 +1468,18 @@ internal_resolve_type_identifier :: proc( ), true } + } else { + + return_symbol, ok = + make_symbol_procedure_from_ast( + ast_context, + local.rhs, + v.type^, + node, + {}, + false, + ), + true } case ^Proc_Group: return_symbol, ok = resolve_function_overload(ast_context, v^) @@ -1581,22 +1589,11 @@ internal_resolve_type_identifier :: proc( make_symbol_bit_field_from_ast(ast_context, v^, node), true return_symbol.name = node.name case ^Proc_Lit: - if !is_procedure_generic(v.type) { - return_symbol, ok = - make_symbol_procedure_from_ast( - ast_context, - global.expr, - v.type^, - node, - global.attributes, - false, - ), - true - } else { - if return_symbol, ok = resolve_generic_function( - ast_context, - v^, - ); !ok { + if is_procedure_generic(v.type) { + return_symbol, ok = resolve_generic_function(ast_context, v^) + + //If we are not overloading just show the unresolved generic function + if !ok && !ast_context.overloading { return_symbol, ok = make_symbol_procedure_from_ast( ast_context, @@ -1608,6 +1605,17 @@ internal_resolve_type_identifier :: proc( ), true } + } else { + return_symbol, ok = + make_symbol_procedure_from_ast( + ast_context, + global.expr, + v.type^, + node, + global.attributes, + false, + ), + true } case ^Proc_Group: return_symbol, ok = resolve_function_overload(ast_context, v^) diff --git a/src/server/generics.odin b/src/server/generics.odin index 9e915c3..5ff5997 100644 --- a/src/server/generics.odin +++ b/src/server/generics.odin @@ -40,6 +40,8 @@ resolve_poly :: proc( case ^ast.Poly_Type: specialization = v.specialization type = v.type + case: + specialization = poly_node } if specialization == nil { @@ -124,7 +126,7 @@ resolve_poly :: proc( if call_struct, ok := call_node.derived.(^ast.Struct_Type); ok { arg_index := 0 struct_value := call_symbol.value.(SymbolStructValue) - + found := false for arg in p.args { if poly_type, ok := arg.derived.(^ast.Poly_Type); ok { if poly_type.type == nil || @@ -140,8 +142,11 @@ resolve_poly :: proc( ) arg_index += 1 + found |= true } } + + return found } case ^ast.Struct_Type: case ^ast.Dynamic_Array_Type: @@ -296,14 +301,13 @@ resolve_poly :: proc( poly_map, ) } + return true } } case ^ast.Ident: - if n, ok := call_node.derived.(^ast.Ident); ok { - return true - } + return true case: - log.error("Unhandled specialization %v", specialization.derived) + return false } return false @@ -585,13 +589,18 @@ resolve_generic_function_symbol :: proc( nil, ) } + } else { + return {}, false } + } else { + return {}, false } i += 1 } } + for k, v in poly_map { find_and_replace_poly_type(v, &poly_map) } @@ -643,28 +652,39 @@ resolve_generic_function_symbol :: proc( append(&return_types, field) } + for param in params { - if len(param.names) == 0 { - continue + field := cast(^ast.Field)clone_node(param, ast_context.allocator, nil) + + if field.type != nil { + if poly_type, ok := field.type.derived.(^ast.Poly_Type); ok { + if expr, ok := poly_map[poly_type.type.name]; ok { + field.type = expr + } + } else { + if ident, ok := unwrap_ident(field.type); ok { + if expr, ok := poly_map[ident.name]; ok { + field.type = expr + } + } + + find_and_replace_poly_type(field.type, &poly_map) + } } - //check the name for poly - if poly_type, ok := param.names[0].derived.(^ast.Poly_Type); - ok && param.type != nil { - if m, ok := poly_map[poly_type.type.name]; ok { - field := cast(^ast.Field)clone_node( - param, - ast_context.allocator, - nil, - ) - field.type = m - append(&argument_types, field) + if len(param.names) > 0 { + if poly_type, ok := param.names[0].derived.(^ast.Poly_Type); + ok && param.type != nil { + if m, ok := poly_map[poly_type.type.name]; ok { + field.type = m + } } - } else { - append(&argument_types, param) } + + append(&argument_types, field) } + symbol.value = SymbolProcedureValue { return_types = return_types[:], arg_types = argument_types[:], @@ -683,12 +703,9 @@ is_procedure_generic :: proc(proc_type: ^ast.Proc_Type) -> bool { continue } - if expr, _, ok := common.unwrap_pointer_expr(param.type); ok { - if _, ok := expr.derived.(^ast.Poly_Type); ok { - return true - } + if common.expr_contains_poly(param.type) { + return true } - } return false diff --git a/tests/completions_test.odin b/tests/completions_test.odin index 60e7070..403ae51 100644 --- a/tests/completions_test.odin +++ b/tests/completions_test.odin @@ -687,11 +687,7 @@ ast_generic_make_completion :: proc(t: ^testing.T) { main = `package test make :: proc{ - make_dynamic_array, make_dynamic_array_len, - make_dynamic_array_len_cap, - make_map, - make_slice, }; make_slice :: proc($T: typeid/[]$E, #any_int len: int, loc := #caller_location) -> (T, Allocator_Error) #optional_second { } @@ -2560,7 +2556,7 @@ ast_poly_proc_matrix_whole :: proc(t: ^testing.T) { matrix4_from_trs_f16 :: proc "contextless" () -> matrix[4, 4]f32 { translation: matrix[4, 4]f32 rotation: matrix[4, 4]f32 - dsszz := matrix_mul(scale, translation) + dsszz := matrix_mul(rotation, translation) dssz{*} } `, @@ -2878,3 +2874,49 @@ ast_enumerated_array_index_completion :: proc(t: ^testing.T) { {"North", "East", "South", "West"}, ) } + +@(test) +ast_raw_data_slice :: proc(t: ^testing.T) { + source := test.Source { + main = `package main + _raw_data_slice :: proc(value: []$E) -> [^]E {} + _raw_data_dynamic :: proc(value: [dynamic]$E) -> [^]E {} + _raw_data_array :: proc(value: ^[$N]$E) -> [^]E {} + _raw_data_simd :: proc(value: ^#simd[$N]$E) -> [^]E {} + _raw_data_string :: proc(value: string) -> [^]byte {} + + _raw_data :: proc{_raw_data_slice, _raw_data_dynamic, _raw_data_array, _raw_data_simd, _raw_data_string} + + main :: proc() { + slice: []int + rezz := _raw_data(slice) + rez{*} + } + `, + } + + test.expect_completion_details(t, &source, "", {"test.rezz: [^]int"}) +} + +@(test) +ast_raw_data_slice_2 :: proc(t: ^testing.T) { + source := test.Source { + main = `package main + raw_data_slice :: proc(v: $T/[]$E) -> [^]E {} + + + cool :: proc { + raw_data_slice, + } + + main :: proc() { + my_slice: []int + rezz := cool(my_slice) + rez{*} + } + + `, + } + + test.expect_completion_details(t, &source, "", {"test.rezz: [^]int"}) +} |