From bb783f8f31797597ca0349434e236e6df923e14b Mon Sep 17 00:00:00 2001 From: Marc Vertes Date: Tue, 7 Nov 2023 21:49:01 +0100 Subject: parser: implement switch statement A VM instruction `EqualSet` has been added to preserve the left operand on the stack in case of failure, to allow efficient multiple tests on the same value. Both the pattern 'if/else if' and the classical case clauses have been implemented. --- lang/token.go | 1 + parser/README.md | 3 +- parser/compiler.go | 16 +++++---- parser/interpreter_test.go | 35 +++++++++++------- parser/parse.go | 90 ++++++++++++++++++++++++++++++++++++---------- vm/vm.go | 12 +++++++ 6 files changed, 120 insertions(+), 37 deletions(-) diff --git a/lang/token.go b/lang/token.go index 880f474..95a1be5 100644 --- a/lang/token.go +++ b/lang/token.go @@ -109,6 +109,7 @@ const ( JumpFalse JumpSetFalse JumpSetTrue + EqualSet ) func (t TokenId) IsKeyword() bool { return t >= Break && t <= Var } diff --git a/parser/README.md b/parser/README.md index c9ff440..0503e28 100644 --- a/parser/README.md +++ b/parser/README.md @@ -61,7 +61,8 @@ Go language support: - [ ] go statement - [x] if statement (including else and else if) - [x] for statement -- [ ] switch statement +- [x] switch statement +- [ ] type switch statement - [x] break statement - [x] continue statement - [ ] fallthrough statement diff --git a/parser/compiler.go b/parser/compiler.go index 6bf9312..e2b7823 100644 --- a/parser/compiler.go +++ b/parser/compiler.go @@ -110,6 +110,9 @@ func (c *Compiler) Codegen(tokens Tokens) (err error) { case lang.Equal: c.Emit(int64(t.Pos), vm.Equal) + case lang.EqualSet: + c.Emit(int64(t.Pos), vm.EqualSet) + case lang.Ident: if i < len(tokens)-1 { switch t1 := tokens[i+1]; t1.Id { @@ -228,12 +231,13 @@ func (c *Compiler) Codegen(tokens Tokens) (err error) { } func (c *Compiler) PrintCode() { - labels := map[int]string{} // labels indexed by code location - data := map[int]string{} // data indexed by frame location + labels := 
map[int][]string{} // labels indexed by code location + data := map[int]string{} // data indexed by frame location for name, sym := range c.symbols { if sym.kind == symLabel || sym.kind == symFunc { - labels[sym.value.(int)] = name + i := sym.value.(int) + labels[i] = append(labels[i], name) } if sym.used { data[sym.index] = name @@ -242,14 +246,14 @@ func (c *Compiler) PrintCode() { fmt.Println("# Code:") for i, l := range c.Code { - if label, ok := labels[i]; ok { + for _, label := range labels[i] { fmt.Println(label + ":") } extra := "" switch l[1] { case vm.Jump, vm.JumpFalse, vm.JumpTrue, vm.JumpSetFalse, vm.JumpSetTrue, vm.Calli: if d, ok := labels[i+(int)(l[2])]; ok { - extra = "// " + d + extra = "// " + d[0] } case vm.Dup, vm.Assign, vm.Fdup, vm.Fassign: if d, ok := data[int(l[2])]; ok { @@ -259,7 +263,7 @@ func (c *Compiler) PrintCode() { fmt.Printf("%4d %-14v %v\n", i, vm.CodeString(l), extra) } - if label, ok := labels[len(c.Code)]; ok { + for _, label := range labels[len(c.Code)] { fmt.Println(label + ":") } fmt.Println("# End code") diff --git a/parser/interpreter_test.go b/parser/interpreter_test.go index 02cd105..d73d282 100644 --- a/parser/interpreter_test.go +++ b/parser/interpreter_test.go @@ -125,22 +125,33 @@ f(3)`, res: "4"}, }) } -/* func TestSwitch(t *testing.T) { - run(t, []etest{{ - src: ` -func f(a int) int { + src0 := `func f(a int) int { switch a { - default: - a = 0 - case 1,2: - a = a+1 - case 3: - a = a+2 + default: a = 0 + case 1,2: a = a+1 + case 3: a = a+2 + } + return a +} +` + src1 := `func f(a int) int { + switch { + case a < 3: return 2 + case a < 5: return 5 + default: a = 0 } return a } -f(3)`, res: "5"}, +` + run(t, []etest{ + {src: src0 + "f(1)", res: "2"}, + {src: src0 + "f(2)", res: "3"}, + {src: src0 + "f(3)", res: "5"}, + {src: src0 + "f(4)", res: "0"}, + + {src: src1 + "f(1)", res: "2"}, + {src: src1 + "f(4)", res: "5"}, + {src: src1 + "f(6)", res: "0"}, }) } -*/ diff --git a/parser/parse.go b/parser/parse.go index 
fb8ae2c..be8b31a 100644 --- a/parser/parse.go +++ b/parser/parse.go @@ -84,7 +84,36 @@ func (p *Parser) Parse(src string) (out Tokens, err error) { if err != nil { return out, err } - log.Println("Parse in:", in) + return p.ParseStmts(in) + /* + log.Println("Parse in:", in) + for len(in) > 0 { + endstmt := in.Index(lang.Semicolon) + if endstmt == -1 { + return out, scanner.ErrBlock + } + // Skip over simple init statements for some tokens (if, for, ...) + if lang.HasInit[in[0].Id] { + for in[endstmt-1].Id != lang.BraceBlock { + e2 := in[endstmt+1:].Index(lang.Semicolon) + if e2 == -1 { + return out, scanner.ErrBlock + } + endstmt += 1 + e2 + } + } + o, err := p.ParseStmt(in[:endstmt]) + if err != nil { + return out, err + } + out = append(out, o...) + in = in[endstmt+1:] + } + return out, err + */ +} + +func (p *Parser) ParseStmts(in Tokens) (out Tokens, err error) { for len(in) > 0 { endstmt := in.Index(lang.Semicolon) if endstmt == -1 { @@ -111,7 +140,7 @@ func (p *Parser) Parse(src string) (out Tokens, err error) { } func (p *Parser) ParseStmt(in Tokens) (out Tokens, err error) { - log.Println("ParseStmt in:", in) + log.Println("ParseStmt in:", in, len(in)) if len(in) == 0 { return nil, nil } @@ -370,17 +399,16 @@ func (p *Parser) ParseSwitch(in Tokens) (out Tokens, err error) { } out = init } + condSwitch := false if len(cond) > 0 { if cond, err = p.ParseExpr(cond); err != nil { return nil, err } - } else { - cond = Tokens{{Id: lang.Ident, Str: "true"}} + out = append(out, cond...) + condSwitch = true } - out = append(out, cond...) // Split switch body into case clauses. clauses, err = p.Scan(in[len(in)-1].Block(), true) - log.Println("## clauses:", clauses) sc := clauses.SplitStart(lang.Case) // Make sure that the default clause is the last. lsc := len(sc) - 1 @@ -391,32 +419,58 @@ func (p *Parser) ParseSwitch(in Tokens) (out Tokens, err error) { } } // Process each clause. 
+ nc := len(sc) - 1 for i, cl := range sc { - co, err := p.ParseCaseClause(cl, i) + co, err := p.ParseCaseClause(cl, i, nc, condSwitch) if err != nil { return nil, err } out = append(out, co...) } + out = append(out, scanner.Token{Id: lang.Label, Str: p.scope + "e"}) return out, err } -func (p *Parser) ParseCaseClause(in Tokens, index int) (out Tokens, err error) { - var initcond, init, cond, body Tokens +func (p *Parser) ParseCaseClause(in Tokens, index, max int, condSwitch bool) (out Tokens, err error) { + in = append(in, scanner.Token{Id: lang.Semicolon}) // Force a ';' at the end of body clause. + var conds, body Tokens tl := in.Split(lang.Colon) if len(tl) != 2 { return nil, errors.New("invalid case clause") } - initcond, body = tl[0][1:], tl[1] - if ii := initcond.Index(lang.Semicolon); ii < 0 { - cond = initcond - } else { - init = initcond[:ii] - cond = initcond[ii+1:] + conds = tl[0][1:] + if body, err = p.ParseStmts(tl[1]); err != nil { + return out, err + } + lcond := conds.Split(lang.Comma) + for i, cond := range lcond { + if cond, err = p.ParseExpr(cond); err != nil { + return out, err + } + txt := fmt.Sprintf("%sc%d.%d", p.scope, index, i) + next := "" + if i == len(lcond)-1 { // End of cond: next, go to next clause or exit + if index < max { + next = fmt.Sprintf("%sc%d.%d", p.scope, index+1, 0) + } else { + next = p.scope + "e" + } + } else { + next = fmt.Sprintf("%sc%d.%d", p.scope, index, i+1) + } + out = append(out, scanner.Token{Id: lang.Label, Str: txt}) + if len(cond) > 0 { + out = append(out, cond...) + if condSwitch { + out = append(out, scanner.Token{Id: lang.EqualSet}) + } + out = append(out, scanner.Token{Id: lang.JumpFalse, Str: "JumpFalse " + next}) + } + out = append(out, body...) 
+ if i != len(lcond)-1 || index != max { + out = append(out, scanner.Token{Id: lang.Goto, Str: "Goto " + p.scope + "e"}) + } + } } - lcond := cond.Split(lang.Comma) - log.Println("# ParseCaseClause:", init, "cond:", cond, len(lcond)) - _ = body return out, err } diff --git a/vm/vm.go b/vm/vm.go index 0be9514..6f2df12 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -22,6 +22,7 @@ const ( Dup // addr -- value ; value = mem[addr] Fdup // addr -- value ; value = mem[addr] Equal // n1 n2 -- cond ; cond = n1 == n2 + EqualSet // n1 n2 -- n1 cond ; cond = n1 == n2 Exit // -- ; Greater // n1 n2 -- cond; cond = n1 > n2 Jump // -- ; ip += $1 @@ -48,6 +49,7 @@ var strop = [...]string{ // for VM tracing. CallX: "CallX", Dup: "Dup", Equal: "Equal", + EqualSet: "EqualSet", Exit: "Exit", Fassign: "Fassign", Fdup: "Fdup", @@ -141,6 +143,16 @@ func (m *Machine) Run() (err error) { case Equal: mem[sp-2] = mem[sp-2].(int) == mem[sp-1].(int) mem = mem[:sp-1] + case EqualSet: + if mem[sp-2].(int) == mem[sp-1].(int) { + // If equal then lhs and rhs are popped, replaced by test result, as in Equal. + mem[sp-2] = true + mem = mem[:sp-1] + } else { + // If not equal then the lhs is left on the stack for further processing. + // This is used to simplify bytecode in case clauses of switch statements. + mem[sp-1] = false + } case Exit: return err case Fdup: -- cgit v1.2.3