diff options
| author | Daniel Gavin <danielgavin5@hotmail.com> | 2022-10-27 16:21:12 +0200 |
|---|---|---|
| committer | Daniel Gavin <danielgavin5@hotmail.com> | 2022-10-27 16:21:12 +0200 |
| commit | 601a1a447aee462575c8de91677991750cb612e5 (patch) | |
| tree | 03e36f3406331675454b7069a0126c88e193ebba /src/server | |
| parent | 42a1ceac0e9a080e11fa94d153c0dc5f3e24738b (diff) | |
Add support for matrix types
Diffstat (limited to 'src/server')
| -rw-r--r-- | src/server/analysis.odin | 145 | ||||
| -rw-r--r-- | src/server/collector.odin | 35 | ||||
| -rw-r--r-- | src/server/semantic_tokens.odin | 28 | ||||
| -rw-r--r-- | src/server/symbol.odin | 11 |
4 files changed, 210 insertions, 9 deletions
diff --git a/src/server/analysis.odin b/src/server/analysis.odin index 52a05f0..5e4076c 100644 --- a/src/server/analysis.odin +++ b/src/server/analysis.odin @@ -1142,6 +1142,13 @@ internal_resolve_type_expression :: proc( ast_context.field_name, ), true + case ^Matrix_Type: + return make_symbol_matrix_from_ast( + ast_context, + v^, + ast_context.field_name, + ), + true case ^Dynamic_Array_Type: return make_symbol_dynamic_array_from_ast( ast_context, @@ -1174,7 +1181,7 @@ internal_resolve_type_expression :: proc( case ^Basic_Directive: return resolve_basic_directive(ast_context, v^) case ^Binary_Expr: - return resolve_first_symbol_from_binary_expression(ast_context, v) + return resolve_binary_expression(ast_context, v) case ^Ident: delete_key(&ast_context.recursion_map, v) return internal_resolve_type_identifier(ast_context, v^) @@ -1246,6 +1253,13 @@ internal_resolve_type_expression :: proc( symbol, ok := internal_resolve_type_expression(ast_context, v.elem) symbol.pointers += 1 return symbol, ok + case ^Matrix_Index_Expr: + if symbol, ok := internal_resolve_type_expression(ast_context, v.expr); + ok { + if mat, ok := symbol.value.(SymbolMatrixValue); ok { + return internal_resolve_type_expression(ast_context, mat.expr) + } + } case ^Index_Expr: indexed, ok := internal_resolve_type_expression(ast_context, v.expr) @@ -1663,6 +1677,9 @@ internal_resolve_type_identifier :: proc( case ^Dynamic_Array_Type: return_symbol, ok = make_symbol_dynamic_array_from_ast(ast_context, v^, node), true + case ^Matrix_Type: + return_symbol, ok = + make_symbol_matrix_from_ast(ast_context, v^, node), true case ^Map_Type: return_symbol, ok = make_symbol_map_from_ast(ast_context, v^, node), true @@ -1760,6 +1777,9 @@ internal_resolve_type_identifier :: proc( case ^Dynamic_Array_Type: return_symbol, ok = make_symbol_dynamic_array_from_ast(ast_context, v^, node), true + case ^Matrix_Type: + return_symbol, ok = + make_symbol_matrix_from_ast(ast_context, v^, node), true case ^Map_Type: return_symbol, ok = make_symbol_map_from_ast(ast_context, v^, node), true @@ -2113,10 +2133,7 @@ resolve_first_symbol_from_binary_expression :: proc( Symbol, bool, ) { - //Fairly simple function to find the earliest identifier symbol in binary expression. - if binary.left != nil { - if ident, ok := binary.left.derived.(^ast.Ident); ok { if s, ok := resolve_type_identifier(ast_context, ident^); ok { return s, ok @@ -2149,6 +2166,72 @@ resolve_first_symbol_from_binary_expression :: proc( return {}, false } +resolve_binary_expression :: proc( + ast_context: ^AstContext, + binary: ^ast.Binary_Expr, +) -> ( + Symbol, + bool, +) { + if binary.left == nil || binary.right == nil { + return {}, false + } + + symbol_a, symbol_b: Symbol + ok_a, ok_b: bool + + if expr, ok := binary.left.derived.(^ast.Binary_Expr); ok { + symbol_a, ok_a = resolve_binary_expression(ast_context, expr) + } else { + ast_context.use_locals = true + symbol_a, ok_a = resolve_type_expression(ast_context, binary.left) + } + + if expr, ok := binary.right.derived.(^ast.Binary_Expr); ok { + symbol_b, ok_b = resolve_binary_expression(ast_context, expr) + } else { + ast_context.use_locals = true + symbol_b, ok_b = resolve_type_expression(ast_context, binary.right) + } + + if !ok_a || !ok_b { + return {}, false + } + + matrix_value_a, is_matrix_a := symbol_a.value.(SymbolMatrixValue) + matrix_value_b, is_matrix_b := symbol_b.value.(SymbolMatrixValue) + + vector_value_a, is_vector_a := symbol_a.value.(SymbolFixedArrayValue) + vector_value_b, is_vector_b := symbol_b.value.(SymbolFixedArrayValue) + + //Handle matrix multication specially because it can actual change the return type dimension + if is_matrix_a && is_matrix_b && binary.op.kind == .Mul { + symbol_a.value = SymbolMatrixValue { + expr = matrix_value_a.expr, + x = matrix_value_a.x, + y = matrix_value_b.y, + } + return symbol_a, true + } else if is_matrix_a && is_vector_b && binary.op.kind == .Mul { + symbol_a.value = SymbolFixedArrayValue { + expr = matrix_value_a.expr, + len = matrix_value_a.y, + } + return symbol_a, true + + } else if is_vector_a && is_matrix_b && binary.op.kind == .Mul { + symbol_a.value = SymbolFixedArrayValue { + expr = matrix_value_b.expr, + len = matrix_value_b.x, + } + return symbol_a, true + } + + + //Otherwise just choose the first type, we do not handle error cases - that is done with the checker + return symbol_a, ok_a +} + find_position_in_call_param :: proc( ast_context: ^AstContext, call: ast.Call_Expr, @@ -2327,6 +2410,28 @@ make_symbol_dynamic_array_from_ast :: proc( return symbol } +make_symbol_matrix_from_ast :: proc( + ast_context: ^AstContext, + v: ast.Matrix_Type, + name: ast.Ident, +) -> Symbol { + symbol := Symbol { + range = common.get_token_range(v.node, ast_context.file.src), + type = .Constant, + pkg = get_package_from_node(v.node), + name = name.name, + } + + symbol.value = SymbolMatrixValue { + expr = v.elem, + x = v.row_count, + y = v.column_count, + } + + return symbol +} + + make_symbol_multi_pointer_from_ast :: proc( ast_context: ^AstContext, v: ast.Multi_Pointer_Type, @@ -3865,13 +3970,19 @@ get_signature :: proc( return "proc" case SymbolStructValue: if is_variable { - return symbol.name + return strings.concatenate( + {pointer_prefix, symbol.name}, + ast_context.allocator, + ) } else { return "struct" } case SymbolUnionValue: if is_variable { - return symbol.name + return strings.concatenate( + {pointer_prefix, symbol.name}, + ast_context.allocator, + ) } else { return "union" } @@ -3901,6 +4012,20 @@ get_signature :: proc( }, allocator = ast_context.allocator, ) + case SymbolMatrixValue: + return strings.concatenate( + a = { + pointer_prefix, + "matrix", + "[", + common.node_to_string(v.x), + ",", + common.node_to_string(v.y), + "]", + common.node_to_string(v.expr), + }, + allocator = ast_context.allocator, + ) case SymbolPackageValue: return "package" case SymbolUntypedValue: @@ -4553,6 +4678,14 @@ get_document_position_node :: proc( } case ^Undef: case ^Basic_Lit: + case ^Matrix_Index_Expr: + get_document_position(n.expr, position_context) + get_document_position(n.row_index, position_context) + get_document_position(n.column_index, position_context) + case ^Matrix_Type: + get_document_position(n.row_count, position_context) + get_document_position(n.column_count, position_context) + get_document_position(n.elem, position_context) case ^Ellipsis: get_document_position(n.expr, position_context) case ^Proc_Lit: diff --git a/src/server/collector.odin b/src/server/collector.odin index 9a7d050..b065a9e 100644 --- a/src/server/collector.odin +++ b/src/server/collector.odin @@ -320,6 +320,36 @@ collect_dynamic_array :: proc( return SymbolDynamicArrayValue{expr = elem} } +collect_matrix :: proc( + collection: ^SymbolCollection, + mat: ast.Matrix_Type, + package_map: map[string]string, +) -> SymbolMatrixValue { + elem := clone_type( + mat.elem, + collection.allocator, + &collection.unique_strings, + ) + + y := clone_type( + mat.column_count, + collection.allocator, + &collection.unique_strings, + ) + + x := clone_type( + mat.row_count, + collection.allocator, + &collection.unique_strings, + ) + + replace_package_alias(elem, package_map, collection) + replace_package_alias(x, package_map, collection) + replace_package_alias(y, package_map, collection) + + return SymbolMatrixValue{expr = elem, x = x, y = y} +} + collect_multi_pointer :: proc( collection: ^SymbolCollection, array: ast.Multi_Pointer_Type, @@ -336,6 +366,7 @@ collect_multi_pointer :: proc( return SymbolMultiPointer{expr = elem} } + collect_generic :: proc( collection: ^SymbolCollection, expr: ^ast.Expr, @@ -410,6 +441,10 @@ collect_symbols :: proc( } #partial switch v in col_expr.derived { + case ^ast.Matrix_Type: + token = v^ + token_type = .Variable + symbol.value = collect_matrix(collection, v^, package_map) case ^ast.Proc_Lit: token = v^ token_type = .Function diff --git a/src/server/semantic_tokens.odin b/src/server/semantic_tokens.odin index f930530..8733be4 100644 --- a/src/server/semantic_tokens.odin +++ b/src/server/semantic_tokens.odin @@ -262,7 +262,6 @@ visit_node :: proc( if symbol_and_node, ok := builder.symbols[cast(uintptr)node]; ok { if .Distinct in symbol_and_node.symbol.flags && symbol_and_node.symbol.type == .Constant { - log.error(symbol_and_node.symbol) write_semantic_node( builder, node, @@ -273,8 +272,7 @@ visit_node :: proc( return } - if symbol_and_node.symbol.type == .Variable || - symbol_and_node.symbol.type == .Constant { + if symbol_and_node.symbol.type == .Variable { write_semantic_node( builder, node, @@ -350,6 +348,14 @@ visit_node :: proc( .Type, .None, ) + case SymbolMatrixValue: + write_semantic_node( + builder, + node, + ast_context.file.src, + .Type, + .None, + ) case: //log.errorf("Unexpected symbol value: %v", symbol.value); //panic(fmt.tprintf("Unexpected symbol value: %v", symbol.value)); @@ -386,6 +392,22 @@ visit_node :: proc( visit(n.stmts, builder, ast_context) case ^Expr_Stmt: visit(n.expr, builder, ast_context) + case ^Matrix_Type: + write_semantic_string( + builder, + n.tok_pos, + "matrix", + ast_context.file.src, + .Keyword, + .None, + ) + visit(n.row_count, builder, ast_context) + visit(n.column_count, builder, ast_context) + visit(n.elem, builder, ast_context) + case ^ast.Matrix_Index_Expr: + visit(n.expr, builder, ast_context) + visit(n.row_index, builder, ast_context) + visit(n.column_index, builder, ast_context) case ^Branch_Stmt: write_semantic_token( builder, diff --git a/src/server/symbol.odin b/src/server/symbol.odin index 0c6a4dc..215e01c 100644 --- a/src/server/symbol.odin +++ b/src/server/symbol.odin @@ -89,6 +89,12 @@ SymbolMapValue :: struct { value: ^ast.Expr, } +SymbolMatrixValue :: struct { + x: ^ast.Expr, + y: ^ast.Expr, + expr: ^ast.Expr, +} + /* Generic symbol that is used by the indexer for any variable type(constants, defined global variables, etc), */ @@ -113,6 +119,7 @@ SymbolValue :: union { SymbolSliceValue, SymbolBasicValue, SymbolUntypedValue, + SymbolMatrixValue, } SymbolFlag :: enum { @@ -178,6 +185,10 @@ free_symbol :: proc(symbol: Symbol, allocator: mem.Allocator) { } switch v in symbol.value { + case SymbolMatrixValue: + common.free_ast(v.expr, allocator) + common.free_ast(v.x, allocator) + common.free_ast(v.y, allocator) case SymbolMultiPointer: common.free_ast(v.expr, allocator) case SymbolProcedureValue: |