aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaniel Gavin <danielgavin5@hotmail.com>2022-10-27 16:21:12 +0200
committerDaniel Gavin <danielgavin5@hotmail.com>2022-10-27 16:21:12 +0200
commit601a1a447aee462575c8de91677991750cb612e5 (patch)
tree03e36f3406331675454b7069a0126c88e193ebba /src
parent42a1ceac0e9a080e11fa94d153c0dc5f3e24738b (diff)
Add support for matrix types
Diffstat (limited to 'src')
-rw-r--r--src/common/ast.odin2
-rw-r--r--src/server/analysis.odin145
-rw-r--r--src/server/collector.odin35
-rw-r--r--src/server/semantic_tokens.odin28
-rw-r--r--src/server/symbol.odin11
-rw-r--r--src/testing/testing.odin4
6 files changed, 213 insertions, 12 deletions
diff --git a/src/common/ast.odin b/src/common/ast.odin
index 2a7a027..732c736 100644
--- a/src/common/ast.odin
+++ b/src/common/ast.odin
@@ -1003,7 +1003,7 @@ repeat :: proc(
count: int,
allocator := context.allocator,
) -> string {
- if count == 0 {
+ if count <= 0 {
return ""
}
return strings.repeat(value, count, allocator)
diff --git a/src/server/analysis.odin b/src/server/analysis.odin
index 52a05f0..5e4076c 100644
--- a/src/server/analysis.odin
+++ b/src/server/analysis.odin
@@ -1142,6 +1142,13 @@ internal_resolve_type_expression :: proc(
ast_context.field_name,
),
true
+ case ^Matrix_Type:
+ return make_symbol_matrix_from_ast(
+ ast_context,
+ v^,
+ ast_context.field_name,
+ ),
+ true
case ^Dynamic_Array_Type:
return make_symbol_dynamic_array_from_ast(
ast_context,
@@ -1174,7 +1181,7 @@ internal_resolve_type_expression :: proc(
case ^Basic_Directive:
return resolve_basic_directive(ast_context, v^)
case ^Binary_Expr:
- return resolve_first_symbol_from_binary_expression(ast_context, v)
+ return resolve_binary_expression(ast_context, v)
case ^Ident:
delete_key(&ast_context.recursion_map, v)
return internal_resolve_type_identifier(ast_context, v^)
@@ -1246,6 +1253,13 @@ internal_resolve_type_expression :: proc(
symbol, ok := internal_resolve_type_expression(ast_context, v.elem)
symbol.pointers += 1
return symbol, ok
+ case ^Matrix_Index_Expr:
+ if symbol, ok := internal_resolve_type_expression(ast_context, v.expr);
+ ok {
+ if mat, ok := symbol.value.(SymbolMatrixValue); ok {
+ return internal_resolve_type_expression(ast_context, mat.expr)
+ }
+ }
case ^Index_Expr:
indexed, ok := internal_resolve_type_expression(ast_context, v.expr)
@@ -1663,6 +1677,9 @@ internal_resolve_type_identifier :: proc(
case ^Dynamic_Array_Type:
return_symbol, ok =
make_symbol_dynamic_array_from_ast(ast_context, v^, node), true
+ case ^Matrix_Type:
+ return_symbol, ok =
+ make_symbol_matrix_from_ast(ast_context, v^, node), true
case ^Map_Type:
return_symbol, ok =
make_symbol_map_from_ast(ast_context, v^, node), true
@@ -1760,6 +1777,9 @@ internal_resolve_type_identifier :: proc(
case ^Dynamic_Array_Type:
return_symbol, ok =
make_symbol_dynamic_array_from_ast(ast_context, v^, node), true
+ case ^Matrix_Type:
+ return_symbol, ok =
+ make_symbol_matrix_from_ast(ast_context, v^, node), true
case ^Map_Type:
return_symbol, ok =
make_symbol_map_from_ast(ast_context, v^, node), true
@@ -2113,10 +2133,7 @@ resolve_first_symbol_from_binary_expression :: proc(
Symbol,
bool,
) {
- //Fairly simple function to find the earliest identifier symbol in binary expression.
-
if binary.left != nil {
-
if ident, ok := binary.left.derived.(^ast.Ident); ok {
if s, ok := resolve_type_identifier(ast_context, ident^); ok {
return s, ok
@@ -2149,6 +2166,72 @@ resolve_first_symbol_from_binary_expression :: proc(
return {}, false
}
+resolve_binary_expression :: proc(
+ ast_context: ^AstContext,
+ binary: ^ast.Binary_Expr,
+) -> (
+ Symbol,
+ bool,
+) {
+ if binary.left == nil || binary.right == nil {
+ return {}, false
+ }
+
+ symbol_a, symbol_b: Symbol
+ ok_a, ok_b: bool
+
+ if expr, ok := binary.left.derived.(^ast.Binary_Expr); ok {
+ symbol_a, ok_a = resolve_binary_expression(ast_context, expr)
+ } else {
+ ast_context.use_locals = true
+ symbol_a, ok_a = resolve_type_expression(ast_context, binary.left)
+ }
+
+ if expr, ok := binary.right.derived.(^ast.Binary_Expr); ok {
+ symbol_b, ok_b = resolve_binary_expression(ast_context, expr)
+ } else {
+ ast_context.use_locals = true
+ symbol_b, ok_b = resolve_type_expression(ast_context, binary.right)
+ }
+
+ if !ok_a || !ok_b {
+ return {}, false
+ }
+
+ matrix_value_a, is_matrix_a := symbol_a.value.(SymbolMatrixValue)
+ matrix_value_b, is_matrix_b := symbol_b.value.(SymbolMatrixValue)
+
+ vector_value_a, is_vector_a := symbol_a.value.(SymbolFixedArrayValue)
+ vector_value_b, is_vector_b := symbol_b.value.(SymbolFixedArrayValue)
+
+ //Handle matrix multication specially because it can actual change the return type dimension
+ if is_matrix_a && is_matrix_b && binary.op.kind == .Mul {
+ symbol_a.value = SymbolMatrixValue {
+ expr = matrix_value_a.expr,
+ x = matrix_value_a.x,
+ y = matrix_value_b.y,
+ }
+ return symbol_a, true
+ } else if is_matrix_a && is_vector_b && binary.op.kind == .Mul {
+ symbol_a.value = SymbolFixedArrayValue {
+ expr = matrix_value_a.expr,
+ len = matrix_value_a.y,
+ }
+ return symbol_a, true
+
+ } else if is_vector_a && is_matrix_b && binary.op.kind == .Mul {
+ symbol_a.value = SymbolFixedArrayValue {
+ expr = matrix_value_b.expr,
+ len = matrix_value_b.x,
+ }
+ return symbol_a, true
+ }
+
+
+ //Otherwise just choose the first type, we do not handle error cases - that is done with the checker
+ return symbol_a, ok_a
+}
+
find_position_in_call_param :: proc(
ast_context: ^AstContext,
call: ast.Call_Expr,
@@ -2327,6 +2410,28 @@ make_symbol_dynamic_array_from_ast :: proc(
return symbol
}
+make_symbol_matrix_from_ast :: proc(
+ ast_context: ^AstContext,
+ v: ast.Matrix_Type,
+ name: ast.Ident,
+) -> Symbol {
+ symbol := Symbol {
+ range = common.get_token_range(v.node, ast_context.file.src),
+ type = .Constant,
+ pkg = get_package_from_node(v.node),
+ name = name.name,
+ }
+
+ symbol.value = SymbolMatrixValue {
+ expr = v.elem,
+ x = v.row_count,
+ y = v.column_count,
+ }
+
+ return symbol
+}
+
+
make_symbol_multi_pointer_from_ast :: proc(
ast_context: ^AstContext,
v: ast.Multi_Pointer_Type,
@@ -3865,13 +3970,19 @@ get_signature :: proc(
return "proc"
case SymbolStructValue:
if is_variable {
- return symbol.name
+ return strings.concatenate(
+ {pointer_prefix, symbol.name},
+ ast_context.allocator,
+ )
} else {
return "struct"
}
case SymbolUnionValue:
if is_variable {
- return symbol.name
+ return strings.concatenate(
+ {pointer_prefix, symbol.name},
+ ast_context.allocator,
+ )
} else {
return "union"
}
@@ -3901,6 +4012,20 @@ get_signature :: proc(
},
allocator = ast_context.allocator,
)
+ case SymbolMatrixValue:
+ return strings.concatenate(
+ a = {
+ pointer_prefix,
+ "matrix",
+ "[",
+ common.node_to_string(v.x),
+ ",",
+ common.node_to_string(v.y),
+ "]",
+ common.node_to_string(v.expr),
+ },
+ allocator = ast_context.allocator,
+ )
case SymbolPackageValue:
return "package"
case SymbolUntypedValue:
@@ -4553,6 +4678,14 @@ get_document_position_node :: proc(
}
case ^Undef:
case ^Basic_Lit:
+ case ^Matrix_Index_Expr:
+ get_document_position(n.expr, position_context)
+ get_document_position(n.row_index, position_context)
+ get_document_position(n.column_index, position_context)
+ case ^Matrix_Type:
+ get_document_position(n.row_count, position_context)
+ get_document_position(n.column_count, position_context)
+ get_document_position(n.elem, position_context)
case ^Ellipsis:
get_document_position(n.expr, position_context)
case ^Proc_Lit:
diff --git a/src/server/collector.odin b/src/server/collector.odin
index 9a7d050..b065a9e 100644
--- a/src/server/collector.odin
+++ b/src/server/collector.odin
@@ -320,6 +320,36 @@ collect_dynamic_array :: proc(
return SymbolDynamicArrayValue{expr = elem}
}
+collect_matrix :: proc(
+ collection: ^SymbolCollection,
+ mat: ast.Matrix_Type,
+ package_map: map[string]string,
+) -> SymbolMatrixValue {
+ elem := clone_type(
+ mat.elem,
+ collection.allocator,
+ &collection.unique_strings,
+ )
+
+ y := clone_type(
+ mat.column_count,
+ collection.allocator,
+ &collection.unique_strings,
+ )
+
+ x := clone_type(
+ mat.row_count,
+ collection.allocator,
+ &collection.unique_strings,
+ )
+
+ replace_package_alias(elem, package_map, collection)
+ replace_package_alias(x, package_map, collection)
+ replace_package_alias(y, package_map, collection)
+
+ return SymbolMatrixValue{expr = elem, x = x, y = y}
+}
+
collect_multi_pointer :: proc(
collection: ^SymbolCollection,
array: ast.Multi_Pointer_Type,
@@ -336,6 +366,7 @@ collect_multi_pointer :: proc(
return SymbolMultiPointer{expr = elem}
}
+
collect_generic :: proc(
collection: ^SymbolCollection,
expr: ^ast.Expr,
@@ -410,6 +441,10 @@ collect_symbols :: proc(
}
#partial switch v in col_expr.derived {
+ case ^ast.Matrix_Type:
+ token = v^
+ token_type = .Variable
+ symbol.value = collect_matrix(collection, v^, package_map)
case ^ast.Proc_Lit:
token = v^
token_type = .Function
diff --git a/src/server/semantic_tokens.odin b/src/server/semantic_tokens.odin
index f930530..8733be4 100644
--- a/src/server/semantic_tokens.odin
+++ b/src/server/semantic_tokens.odin
@@ -262,7 +262,6 @@ visit_node :: proc(
if symbol_and_node, ok := builder.symbols[cast(uintptr)node]; ok {
if .Distinct in symbol_and_node.symbol.flags &&
symbol_and_node.symbol.type == .Constant {
- log.error(symbol_and_node.symbol)
write_semantic_node(
builder,
node,
@@ -273,8 +272,7 @@ visit_node :: proc(
return
}
- if symbol_and_node.symbol.type == .Variable ||
- symbol_and_node.symbol.type == .Constant {
+ if symbol_and_node.symbol.type == .Variable {
write_semantic_node(
builder,
node,
@@ -350,6 +348,14 @@ visit_node :: proc(
.Type,
.None,
)
+ case SymbolMatrixValue:
+ write_semantic_node(
+ builder,
+ node,
+ ast_context.file.src,
+ .Type,
+ .None,
+ )
case:
//log.errorf("Unexpected symbol value: %v", symbol.value);
//panic(fmt.tprintf("Unexpected symbol value: %v", symbol.value));
@@ -386,6 +392,22 @@ visit_node :: proc(
visit(n.stmts, builder, ast_context)
case ^Expr_Stmt:
visit(n.expr, builder, ast_context)
+ case ^Matrix_Type:
+ write_semantic_string(
+ builder,
+ n.tok_pos,
+ "matrix",
+ ast_context.file.src,
+ .Keyword,
+ .None,
+ )
+ visit(n.row_count, builder, ast_context)
+ visit(n.column_count, builder, ast_context)
+ visit(n.elem, builder, ast_context)
+ case ^ast.Matrix_Index_Expr:
+ visit(n.expr, builder, ast_context)
+ visit(n.row_index, builder, ast_context)
+ visit(n.column_index, builder, ast_context)
case ^Branch_Stmt:
write_semantic_token(
builder,
diff --git a/src/server/symbol.odin b/src/server/symbol.odin
index 0c6a4dc..215e01c 100644
--- a/src/server/symbol.odin
+++ b/src/server/symbol.odin
@@ -89,6 +89,12 @@ SymbolMapValue :: struct {
value: ^ast.Expr,
}
+SymbolMatrixValue :: struct {
+ x: ^ast.Expr,
+ y: ^ast.Expr,
+ expr: ^ast.Expr,
+}
+
/*
Generic symbol that is used by the indexer for any variable type(constants, defined global variables, etc),
*/
@@ -113,6 +119,7 @@ SymbolValue :: union {
SymbolSliceValue,
SymbolBasicValue,
SymbolUntypedValue,
+ SymbolMatrixValue,
}
SymbolFlag :: enum {
@@ -178,6 +185,10 @@ free_symbol :: proc(symbol: Symbol, allocator: mem.Allocator) {
}
switch v in symbol.value {
+ case SymbolMatrixValue:
+ common.free_ast(v.expr, allocator)
+ common.free_ast(v.x, allocator)
+ common.free_ast(v.y, allocator)
case SymbolMultiPointer:
common.free_ast(v.expr, allocator)
case SymbolProcedureValue:
diff --git a/src/testing/testing.odin b/src/testing/testing.odin
index e4ac86d..0d97870 100644
--- a/src/testing/testing.odin
+++ b/src/testing/testing.odin
@@ -58,9 +58,9 @@ setup :: proc(src: ^Source) {
} else if current == '\n' {
current_line += 1
current_character = 0
- } else if current == '*' {
+ } else if src.main[current_index:current_index + 3] == "{*}" {
dst_slice := transmute([]u8)src.main[current_index:]
- src_slice := transmute([]u8)src.main[current_index + 1:]
+ src_slice := transmute([]u8)src.main[current_index + 3:]
copy(dst_slice, src_slice)
src.position.character = current_character
src.position.line = current_line