diff options
| author | DanielGavin <danielgavin5@hotmail.com> | 2023-12-29 17:43:10 +0100 |
|---|---|---|
| committer | DanielGavin <danielgavin5@hotmail.com> | 2023-12-29 17:43:10 +0100 |
| commit | e93830d576ea387cc7dc635306ce07f43c515853 (patch) | |
| tree | b810eadd2dfdbaa3a72f90e9db81ade7ca36ac06 | |
| parent | 589f609f76040c3be20adfa799d8b79345477882 (diff) | |
More poly work
| -rw-r--r-- | src/server/generics.odin | 93 | ||||
| -rw-r--r-- | src/server/symbol.odin | 7 | ||||
| -rw-r--r-- | tests/completions_test.odin | 78 |
3 files changed, 171 insertions, 7 deletions
diff --git a/src/server/generics.odin b/src/server/generics.odin index 5d07181..7141a75 100644 --- a/src/server/generics.odin +++ b/src/server/generics.odin @@ -57,6 +57,61 @@ resolve_poly :: proc( } #partial switch p in specialization.derived { + case ^ast.Matrix_Type: + if call_matrix, ok := call_node.derived.(^ast.Matrix_Type); ok { + found := false + if poly_type, ok := p.row_count.derived.(^ast.Poly_Type); ok { + if ident, ok := unwrap_ident(poly_type.type); ok { + poly_map[ident.name] = call_matrix.row_count + } + + if poly_type.specialization != nil { + return resolve_poly( + ast_context, + call_matrix.row_count, + call_symbol, + p.row_count, + poly_map, + ) + } + found |= true + } + + if poly_type, ok := p.column_count.derived.(^ast.Poly_Type); ok { + if ident, ok := unwrap_ident(poly_type.type); ok { + poly_map[ident.name] = call_matrix.column_count + } + + if poly_type.specialization != nil { + return resolve_poly( + ast_context, + call_matrix.column_count, + call_symbol, + p.column_count, + poly_map, + ) + } + found |= true + } + + if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { + if ident, ok := unwrap_ident(poly_type.type); ok { + poly_map[ident.name] = call_matrix.elem + } + + if poly_type.specialization != nil { + return resolve_poly( + ast_context, + call_matrix.elem, + call_symbol, + p.elem, + poly_map, + ) + } + found |= true + } + return found + } case ^ast.Call_Expr: if call_struct, ok := call_node.derived.(^ast.Struct_Type); ok { arg_index := 0 @@ -78,7 +133,6 @@ resolve_poly :: proc( } } case ^ast.Struct_Type: - case ^ast.Dynamic_Array_Type: if call_array, ok := call_node.derived.(^ast.Dynamic_Array_Type); ok { if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { @@ -100,6 +154,7 @@ resolve_poly :: proc( } case ^ast.Array_Type: if call_array, ok := call_node.derived.(^ast.Array_Type); ok { + found := false if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { poly_map[ident.name] = call_array.elem @@ -114,11 +169,31 @@ resolve_poly :: proc( poly_map, ) } - return true + found |= true } + + if poly_type, ok := p.len.derived.(^ast.Poly_Type); ok { + if ident, ok := unwrap_ident(poly_type.type); ok { + poly_map[ident.name] = call_array.len + } + + if poly_type.specialization != nil { + return resolve_poly( + ast_context, + call_array.len, + call_symbol, + p.len, + poly_map, + ) + } + found |= true + } + + return found } case ^ast.Map_Type: if call_map, ok := call_node.derived.(^ast.Map_Type); ok { + found := false if poly_type, ok := p.key.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { poly_map[ident.name] = call_map.key @@ -133,8 +208,10 @@ resolve_poly :: proc( poly_map, ) } - return true - } else if poly_type, ok := p.value.derived.(^ast.Poly_Type); ok { + found |= true + } + + if poly_type, ok := p.value.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { poly_map[ident.name] = call_map.value } @@ -148,8 +225,9 @@ resolve_poly :: proc( poly_map, ) } - return true + found |= true } + return found } case ^ast.Multi_Pointer_Type: if call_pointer, ok := call_node.derived.(^ast.Multi_Pointer_Type); @@ -224,7 +302,8 @@ find_and_replace_poly_type :: proc( case ^ast.Array_Type: if expr, ok := is_in_poly_map(v.elem, poly_map); ok { v.elem = expr - } else if expr, ok := is_in_poly_map(v.len, poly_map); ok { + } + if expr, ok := is_in_poly_map(v.len, poly_map); ok { v.len = expr } case ^ast.Multi_Pointer_Type: @@ -358,7 +437,7 @@ resolve_generic_function_symbol :: proc( } for k, v in poly_map { - //fmt.println(k, v.derived) + //fmt.println(k, v.derived, "\n") } if count_required_params > len(call_expr.args) || diff --git a/src/server/symbol.odin b/src/server/symbol.odin index a62c8cf..0f86d5c 100644 --- a/src/server/symbol.odin +++ b/src/server/symbol.odin @@ -314,6 +314,7 @@ symbol_to_expr :: proc( case SymbolFixedArrayValue: type := new_type(ast.Array_Type, pos, end, allocator) type.elem = v.expr + type.len = v.len return type case SymbolMapValue: type := new_type(ast.Map_Type, pos, end, allocator) @@ -332,6 +333,12 @@ symbol_to_expr :: proc( case SymbolUntypedValue: type := new_type(ast.Basic_Lit, pos, end, allocator) return type + case SymbolMatrixValue: + type := new_type(ast.Matrix_Type, pos, end, allocator) + type.row_count = v.x + type.column_count = v.y + type.elem = v.expr + return type case: log.errorf("Unhandled symbol %v", symbol) } diff --git a/tests/completions_test.odin b/tests/completions_test.odin index fc9575c..2e811a6 100644 --- a/tests/completions_test.odin +++ b/tests/completions_test.odin @@ -2494,3 +2494,81 @@ ast_poly_proc_array_constant :: proc(t: ^testing.T) { test.expect_completion_details(t, &source, "", {"test.array: [3]f32"}) } + +@(test) +ast_poly_proc_matrix_type :: proc(t: ^testing.T) { + packages := make([dynamic]test.Package) + + source := test.Source { + main = `package test + + matrix_to_ptr :: proc "contextless" (m: ^$A/matrix[$I, $J]$E) -> ^E { + return &m[0, 0] + } + + + main :: proc() { + my_matrix: matrix[2, 2]f32 + ptr := matrix_to_ptr(&my_matrix) + pt{*} + } + + `, + packages = packages[:], + } + + test.expect_completion_details(&t, &source, "", {"test.ptr: ^f32"}) +} + +@(test) +ast_poly_proc_matrix_constant_array :: proc(t: ^testing.T) { + packages := make([dynamic]test.Package) + + source := test.Source { + main = `package test + + matrix_to_ptr :: proc "contextless" (m: ^$A/matrix[$I, $J]$E) -> [J]E { + return {} + } + + main :: proc() { + my_matrix: matrix[4, 3]f32 + + ptr := matrix_to_ptr(&my_matrix) + pt{*} + } + `, + packages = packages[:], + } + + test.expect_completion_details(&t, &source, "", {"test.ptr: [3]f32"}) +} + +@(test) +ast_poly_proc_matrix_constant_array_2 :: proc(t: ^testing.T) { + packages := make([dynamic]test.Package) + + source := test.Source { + main = `package test + array_cast :: proc "contextless" ( + v: $A/[$N]$T, + $Elem_Type: typeid, + ) -> ( + w: [N]Elem_Type, + ) { + for i in 0 ..< N { + w[i] = Elem_Type(v[i]) + } + return + } + main :: proc() { + my_vector: [10]int + myss := array_cast(my_vector, f32) + mys{*} + } + `, + packages = packages[:], + } + + test.expect_completion_details(&t, &source, "", {"test.myss: [10]f32"}) +} |