aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/common/ast.odin29
-rw-r--r--src/server/analysis.odin72
-rw-r--r--src/server/generics.odin67
-rw-r--r--tests/completions_test.odin52
4 files changed, 158 insertions, 62 deletions
diff --git a/src/common/ast.odin b/src/common/ast.odin
index 29d3e6c..b215eb8 100644
--- a/src/common/ast.odin
+++ b/src/common/ast.odin
@@ -207,6 +207,35 @@ unwrap_pointer_expr :: proc(expr: ^ast.Expr) -> (^ast.Expr, int, bool) {
return expr, n, true
}
+expr_contains_poly :: proc(expr: ^ast.Expr) -> bool {
+ if expr == nil {
+ return false
+ }
+
+ visit :: proc(visitor: ^ast.Visitor, node: ^ast.Node) -> ^ast.Visitor {
+ if node == nil {
+ return nil
+ }
+ if _, ok := node.derived.(^ast.Poly_Type); ok {
+ b := cast(^bool)visitor.data
+ b^ = true
+ return nil
+ }
+ return visitor
+ }
+
+ found := false
+
+ visitor := ast.Visitor {
+ visit = visit,
+ data = &found,
+ }
+
+ ast.walk(&visitor, expr)
+
+ return found
+}
+
is_expr_basic_lit :: proc(expr: ^ast.Expr) -> bool {
_, ok := expr.derived.(^ast.Basic_Lit)
return ok
diff --git a/src/server/analysis.odin b/src/server/analysis.odin
index d6592ae..f111c5d 100644
--- a/src/server/analysis.odin
+++ b/src/server/analysis.odin
@@ -101,6 +101,7 @@ AstContext :: struct {
uri: string,
fullpath: string,
non_mutable_only: bool,
+ overloading: bool,
}
make_ast_context :: proc(
@@ -617,6 +618,13 @@ resolve_function_overload :: proc(
Symbol,
bool,
) {
+ old_overloading := ast_context.overloading
+ ast_context.overloading = true
+
+ defer {
+ ast_context.overloading = old_overloading
+ }
+
using ast
call_expr := ast_context.call
@@ -1445,22 +1453,10 @@ internal_resolve_type_identifier :: proc(
make_symbol_bit_field_from_ast(ast_context, v^, node), true
return_symbol.name = node.name
case ^Proc_Lit:
- if !is_procedure_generic(v.type) {
- return_symbol, ok =
- make_symbol_procedure_from_ast(
- ast_context,
- local.rhs,
- v.type^,
- node,
- {},
- false,
- ),
- true
- } else {
- if return_symbol, ok = resolve_generic_function(
- ast_context,
- v^,
- ); !ok {
+ if is_procedure_generic(v.type) {
+ return_symbol, ok = resolve_generic_function(ast_context, v^)
+
+ if !ok && !ast_context.overloading {
return_symbol, ok =
make_symbol_procedure_from_ast(
ast_context,
@@ -1472,6 +1468,18 @@ internal_resolve_type_identifier :: proc(
),
true
}
+ } else {
+
+ return_symbol, ok =
+ make_symbol_procedure_from_ast(
+ ast_context,
+ local.rhs,
+ v.type^,
+ node,
+ {},
+ false,
+ ),
+ true
}
case ^Proc_Group:
return_symbol, ok = resolve_function_overload(ast_context, v^)
@@ -1581,22 +1589,11 @@ internal_resolve_type_identifier :: proc(
make_symbol_bit_field_from_ast(ast_context, v^, node), true
return_symbol.name = node.name
case ^Proc_Lit:
- if !is_procedure_generic(v.type) {
- return_symbol, ok =
- make_symbol_procedure_from_ast(
- ast_context,
- global.expr,
- v.type^,
- node,
- global.attributes,
- false,
- ),
- true
- } else {
- if return_symbol, ok = resolve_generic_function(
- ast_context,
- v^,
- ); !ok {
+ if is_procedure_generic(v.type) {
+ return_symbol, ok = resolve_generic_function(ast_context, v^)
+
+ //If we are not overloading just show the unresolved generic function
+ if !ok && !ast_context.overloading {
return_symbol, ok =
make_symbol_procedure_from_ast(
ast_context,
@@ -1608,6 +1605,17 @@ internal_resolve_type_identifier :: proc(
),
true
}
+ } else {
+ return_symbol, ok =
+ make_symbol_procedure_from_ast(
+ ast_context,
+ global.expr,
+ v.type^,
+ node,
+ global.attributes,
+ false,
+ ),
+ true
}
case ^Proc_Group:
return_symbol, ok = resolve_function_overload(ast_context, v^)
diff --git a/src/server/generics.odin b/src/server/generics.odin
index 9e915c3..5ff5997 100644
--- a/src/server/generics.odin
+++ b/src/server/generics.odin
@@ -40,6 +40,8 @@ resolve_poly :: proc(
case ^ast.Poly_Type:
specialization = v.specialization
type = v.type
+ case:
+ specialization = poly_node
}
if specialization == nil {
@@ -124,7 +126,7 @@ resolve_poly :: proc(
if call_struct, ok := call_node.derived.(^ast.Struct_Type); ok {
arg_index := 0
struct_value := call_symbol.value.(SymbolStructValue)
-
+ found := false
for arg in p.args {
if poly_type, ok := arg.derived.(^ast.Poly_Type); ok {
if poly_type.type == nil ||
@@ -140,8 +142,11 @@ resolve_poly :: proc(
)
arg_index += 1
+ found |= true
}
}
+
+ return found
}
case ^ast.Struct_Type:
case ^ast.Dynamic_Array_Type:
@@ -296,14 +301,13 @@ resolve_poly :: proc(
poly_map,
)
}
+ return true
}
}
case ^ast.Ident:
- if n, ok := call_node.derived.(^ast.Ident); ok {
- return true
- }
+ return true
case:
- log.error("Unhandled specialization %v", specialization.derived)
+ return false
}
return false
@@ -585,13 +589,18 @@ resolve_generic_function_symbol :: proc(
nil,
)
}
+ } else {
+ return {}, false
}
+ } else {
+ return {}, false
}
i += 1
}
}
+
for k, v in poly_map {
find_and_replace_poly_type(v, &poly_map)
}
@@ -643,28 +652,39 @@ resolve_generic_function_symbol :: proc(
append(&return_types, field)
}
+
for param in params {
- if len(param.names) == 0 {
- continue
+ field := cast(^ast.Field)clone_node(param, ast_context.allocator, nil)
+
+ if field.type != nil {
+ if poly_type, ok := field.type.derived.(^ast.Poly_Type); ok {
+ if expr, ok := poly_map[poly_type.type.name]; ok {
+ field.type = expr
+ }
+ } else {
+ if ident, ok := unwrap_ident(field.type); ok {
+ if expr, ok := poly_map[ident.name]; ok {
+ field.type = expr
+ }
+ }
+
+ find_and_replace_poly_type(field.type, &poly_map)
+ }
}
- //check the name for poly
- if poly_type, ok := param.names[0].derived.(^ast.Poly_Type);
- ok && param.type != nil {
- if m, ok := poly_map[poly_type.type.name]; ok {
- field := cast(^ast.Field)clone_node(
- param,
- ast_context.allocator,
- nil,
- )
- field.type = m
- append(&argument_types, field)
+ if len(param.names) > 0 {
+ if poly_type, ok := param.names[0].derived.(^ast.Poly_Type);
+ ok && param.type != nil {
+ if m, ok := poly_map[poly_type.type.name]; ok {
+ field.type = m
+ }
}
- } else {
- append(&argument_types, param)
}
+
+ append(&argument_types, field)
}
+
symbol.value = SymbolProcedureValue {
return_types = return_types[:],
arg_types = argument_types[:],
@@ -683,12 +703,9 @@ is_procedure_generic :: proc(proc_type: ^ast.Proc_Type) -> bool {
continue
}
- if expr, _, ok := common.unwrap_pointer_expr(param.type); ok {
- if _, ok := expr.derived.(^ast.Poly_Type); ok {
- return true
- }
+ if common.expr_contains_poly(param.type) {
+ return true
}
-
}
return false
diff --git a/tests/completions_test.odin b/tests/completions_test.odin
index 60e7070..403ae51 100644
--- a/tests/completions_test.odin
+++ b/tests/completions_test.odin
@@ -687,11 +687,7 @@ ast_generic_make_completion :: proc(t: ^testing.T) {
main = `package test
make :: proc{
- make_dynamic_array,
make_dynamic_array_len,
- make_dynamic_array_len_cap,
- make_map,
- make_slice,
};
make_slice :: proc($T: typeid/[]$E, #any_int len: int, loc := #caller_location) -> (T, Allocator_Error) #optional_second {
}
@@ -2560,7 +2556,7 @@ ast_poly_proc_matrix_whole :: proc(t: ^testing.T) {
matrix4_from_trs_f16 :: proc "contextless" () -> matrix[4, 4]f32 {
translation: matrix[4, 4]f32
rotation: matrix[4, 4]f32
- dsszz := matrix_mul(scale, translation)
+ dsszz := matrix_mul(rotation, translation)
dssz{*}
}
`,
@@ -2878,3 +2874,49 @@ ast_enumerated_array_index_completion :: proc(t: ^testing.T) {
{"North", "East", "South", "West"},
)
}
+
+@(test)
+ast_raw_data_slice :: proc(t: ^testing.T) {
+ source := test.Source {
+ main = `package main
+ _raw_data_slice :: proc(value: []$E) -> [^]E {}
+ _raw_data_dynamic :: proc(value: [dynamic]$E) -> [^]E {}
+ _raw_data_array :: proc(value: ^[$N]$E) -> [^]E {}
+ _raw_data_simd :: proc(value: ^#simd[$N]$E) -> [^]E {}
+ _raw_data_string :: proc(value: string) -> [^]byte {}
+
+ _raw_data :: proc{_raw_data_slice, _raw_data_dynamic, _raw_data_array, _raw_data_simd, _raw_data_string}
+
+ main :: proc() {
+ slice: []int
+ rezz := _raw_data(slice)
+ rez{*}
+ }
+ `,
+ }
+
+ test.expect_completion_details(t, &source, "", {"test.rezz: [^]int"})
+}
+
+@(test)
+ast_raw_data_slice_2 :: proc(t: ^testing.T) {
+ source := test.Source {
+ main = `package main
+ raw_data_slice :: proc(v: $T/[]$E) -> [^]E {}
+
+
+ cool :: proc {
+ raw_data_slice,
+ }
+
+ main :: proc() {
+ my_slice: []int
+ rezz := cool(my_slice)
+ rez{*}
+ }
+
+ `,
+ }
+
+ test.expect_completion_details(t, &source, "", {"test.rezz: [^]int"})
+}