diff options
| author | Nathaniel Saxe <NathanielSaxophone@gmail.com> | 2026-01-02 12:57:54 -0500 |
|---|---|---|
| committer | Nathaniel Saxe <NathanielSaxophone@gmail.com> | 2026-01-02 12:57:54 -0500 |
| commit | 0ad323e79c56a41e0e4818c646088cce983fade4 (patch) | |
| tree | 3a9298aa4f3a0779cc535754afc2432e027b3414 | |
| parent | 40706139690b2eac3d5321e57600d0f414cd179b (diff) | |
make it work for enum switch statements as well
| -rw-r--r-- | src/server/action.odin | 91 |
1 files changed, 66 insertions, 25 deletions
diff --git a/src/server/action.odin b/src/server/action.odin index e174f7e..72ef888 100644 --- a/src/server/action.odin +++ b/src/server/action.odin @@ -208,35 +208,71 @@ get_switch_cases_info :: proc( position_context: ^DocumentPositionContext, ) -> ( existing_cases: map[string]string, - enum_names: []string, + all_case_names: []string, + is_enum: bool, ok: bool, ) { - if switch_block, ok := position_context.switch_stmt.body.derived.(^ast.Block_Stmt); ok { - existing_cases = make(map[string]string, 5, context.temp_allocator) - for stmt in switch_block.stmts { - if case_clause, ok := stmt.derived.(^ast.Case_Clause); ok { - case_name := "" - for name in case_clause.list { + log.error(position_context.switch_stmt, position_context.switch_type_stmt) + if (position_context.switch_stmt == nil && position_context.switch_type_stmt == nil) || + (position_context.switch_stmt != nil && position_context.switch_stmt.cond == nil) { + return nil, nil, false, false + } + switch_block: ^ast.Block_Stmt + found_switch_block: bool + if position_context.switch_stmt != nil { + switch_block, found_switch_block = position_context.switch_stmt.body.derived.(^ast.Block_Stmt) + is_enum = true + } + if !found_switch_block && position_context.switch_type_stmt != nil { + switch_block, found_switch_block = position_context.switch_type_stmt.body.derived.(^ast.Block_Stmt) + } + if !found_switch_block { + return nil, nil, false, false + } + existing_cases = make(map[string]string, 5, context.temp_allocator) + for stmt in switch_block.stmts { + if case_clause, ok := stmt.derived.(^ast.Case_Clause); ok { + case_name := "" + for name in case_clause.list { + if is_enum { if implicit, ok := name.derived.(^ast.Implicit_Selector_Expr); ok { case_name = implicit.field.name break } + } else { + if ident, ok := name.derived.(^ast.Ident); ok { + case_name = ident.name + break + } } - if case_name != "" { - existing_cases[case_name] = get_block_original_text(case_clause.body, document.text) - } + } + if case_name != "" { + existing_cases[case_name] = get_block_original_text(case_clause.body, document.text) } } } - enum_value, _, unwrap_ok := unwrap_enum(ast_context, position_context.switch_stmt.cond) - if !unwrap_ok {return nil, nil, false} - return existing_cases, enum_value.names, true + log.error(existing_cases) + if is_enum { + enum_value, was_super_enum, unwrap_ok := unwrap_enum(ast_context, position_context.switch_stmt.cond) + if !unwrap_ok {return nil, nil, true, false} + return existing_cases, enum_value.names, !was_super_enum, true + } else { + st := position_context.switch_type_stmt + union_value, unwrap_ok := unwrap_union(ast_context, st.tag.derived.(^ast.Assign_Stmt).rhs[0]) + if !unwrap_ok {return nil, nil, false, false} + case_names := make([]string, len(union_value.types), context.temp_allocator) + for t, i in union_value.types { + case_names[i] = t.derived.(^ast.Ident).name + } + return existing_cases, case_names, false, true + } } create_populate_switch_cases_edit :: proc( position_context: ^DocumentPositionContext, existing_cases: map[string]string, - enum_names: []string, + all_case_names: []string, + is_enum: bool, ) -> ( TextEdit, bool, @@ -246,12 +282,18 @@ create_populate_switch_cases_edit :: proc( return {}, false } //entirety of the switch block - range := common.get_token_range(position_context.switch_stmt.body.stmt_base, position_context.file.src) + range: common.Range + if is_enum { + range = common.get_token_range(position_context.switch_stmt.body.stmt_base, position_context.file.src) + } else { + range = common.get_token_range(position_context.switch_type_stmt.body.stmt_base, position_context.file.src) + } replacement_builder := strings.builder_make() + dot := is_enum ? "." : "" b := &replacement_builder fmt.sbprintln(b, "{") - for name in enum_names { - fmt.sbprintln(b, "case .", name, ":", sep = "") + for name in all_case_names { + fmt.sbprintln(b, "case ", dot, name, ":", sep = "") if name in existing_cases { case_block := existing_cases[name] if case_block != "" { @@ -260,11 +302,10 @@ create_populate_switch_cases_edit :: proc( } } for name in existing_cases { - if !slice.contains(enum_names, name) { - //this case probably shouldn't exist - //since it's not one of the legal enum names, - //but don't delete the user's code inside the block - fmt.sbprintln(b, "case .", name, ":", sep = "") + if !slice.contains(all_case_names, name) { + //this case probably should be deleted by the user since it's not one of the legal enum names, + //but we shouldn't preemptively delete the user's code inside the block + fmt.sbprintln(b, "case ", dot, name, ":", sep = "") case_block := existing_cases[name] if case_block != "" { fmt.sbprintln(b, existing_cases[name]) @@ -281,16 +322,16 @@ add_populate_switch_cases_action :: proc( uri: string, actions: ^[dynamic]CodeAction, ) { - existing_cases, enum_names, ok := get_switch_cases_info(document, ast_context, position_context) + existing_cases, all_case_names, is_enum, ok := get_switch_cases_info(document, ast_context, position_context) if !ok {return} all_cases_covered := true - for name in enum_names { + for name in all_case_names { if name not_in existing_cases { all_cases_covered = false } } if all_cases_covered {return} //action not needed - edit, edit_ok := create_populate_switch_cases_edit(position_context, existing_cases, enum_names) + edit, edit_ok := create_populate_switch_cases_edit(position_context, existing_cases, all_case_names, is_enum) if !edit_ok {return} textEdits := make([dynamic]TextEdit, context.temp_allocator) append(&textEdits, edit) |