summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lang/token.go1
-rw-r--r--parser/README.md3
-rw-r--r--parser/compiler.go16
-rw-r--r--parser/interpreter_test.go35
-rw-r--r--parser/parse.go90
-rw-r--r--vm/vm.go12
6 files changed, 120 insertions, 37 deletions
diff --git a/lang/token.go b/lang/token.go
index 880f474..95a1be5 100644
--- a/lang/token.go
+++ b/lang/token.go
@@ -109,6 +109,7 @@ const (
JumpFalse
JumpSetFalse
JumpSetTrue
+ EqualSet
)
func (t TokenId) IsKeyword() bool { return t >= Break && t <= Var }
diff --git a/parser/README.md b/parser/README.md
index c9ff440..0503e28 100644
--- a/parser/README.md
+++ b/parser/README.md
@@ -61,7 +61,8 @@ Go language support:
- [ ] go statement
- [x] if statement (including else and else if)
- [x] for statement
-- [ ] switch statement
+- [x] switch statement
+- [ ] type switch statement
- [x] break statement
- [x] continue statement
- [ ] fallthrough statement
diff --git a/parser/compiler.go b/parser/compiler.go
index 6bf9312..e2b7823 100644
--- a/parser/compiler.go
+++ b/parser/compiler.go
@@ -110,6 +110,9 @@ func (c *Compiler) Codegen(tokens Tokens) (err error) {
case lang.Equal:
c.Emit(int64(t.Pos), vm.Equal)
+ case lang.EqualSet:
+ c.Emit(int64(t.Pos), vm.EqualSet)
+
case lang.Ident:
if i < len(tokens)-1 {
switch t1 := tokens[i+1]; t1.Id {
@@ -228,12 +231,13 @@ func (c *Compiler) Codegen(tokens Tokens) (err error) {
}
func (c *Compiler) PrintCode() {
- labels := map[int]string{} // labels indexed by code location
- data := map[int]string{} // data indexed by frame location
+ labels := map[int][]string{} // labels indexed by code location
+ data := map[int]string{} // data indexed by frame location
for name, sym := range c.symbols {
if sym.kind == symLabel || sym.kind == symFunc {
- labels[sym.value.(int)] = name
+ i := sym.value.(int)
+ labels[i] = append(labels[i], name)
}
if sym.used {
data[sym.index] = name
@@ -242,14 +246,14 @@ func (c *Compiler) PrintCode() {
fmt.Println("# Code:")
for i, l := range c.Code {
- if label, ok := labels[i]; ok {
+ for _, label := range labels[i] {
fmt.Println(label + ":")
}
extra := ""
switch l[1] {
case vm.Jump, vm.JumpFalse, vm.JumpTrue, vm.JumpSetFalse, vm.JumpSetTrue, vm.Calli:
if d, ok := labels[i+(int)(l[2])]; ok {
- extra = "// " + d
+ extra = "// " + d[0]
}
case vm.Dup, vm.Assign, vm.Fdup, vm.Fassign:
if d, ok := data[int(l[2])]; ok {
@@ -259,7 +263,7 @@ func (c *Compiler) PrintCode() {
fmt.Printf("%4d %-14v %v\n", i, vm.CodeString(l), extra)
}
- if label, ok := labels[len(c.Code)]; ok {
+ for _, label := range labels[len(c.Code)] {
fmt.Println(label + ":")
}
fmt.Println("# End code")
diff --git a/parser/interpreter_test.go b/parser/interpreter_test.go
index 02cd105..d73d282 100644
--- a/parser/interpreter_test.go
+++ b/parser/interpreter_test.go
@@ -125,22 +125,33 @@ f(3)`, res: "4"},
})
}
-/*
func TestSwitch(t *testing.T) {
- run(t, []etest{{
- src: `
-func f(a int) int {
+ src0 := `func f(a int) int {
switch a {
- default:
- a = 0
- case 1,2:
- a = a+1
- case 3:
- a = a+2
+ default: a = 0
+ case 1,2: a = a+1
+ case 3: a = a+2
+ }
+ return a
+}
+`
+ src1 := `func f(a int) int {
+ switch {
+ case a < 3: return 2
+ case a < 5: return 5
+ default: a = 0
}
return a
}
-f(3)`, res: "5"},
+`
+ run(t, []etest{
+ {src: src0 + "f(1)", res: "2"},
+ {src: src0 + "f(2)", res: "3"},
+ {src: src0 + "f(3)", res: "5"},
+ {src: src0 + "f(4)", res: "0"},
+
+ {src: src1 + "f(1)", res: "2"},
+ {src: src1 + "f(4)", res: "5"},
+ {src: src1 + "f(6)", res: "0"},
})
}
-*/
diff --git a/parser/parse.go b/parser/parse.go
index fb8ae2c..be8b31a 100644
--- a/parser/parse.go
+++ b/parser/parse.go
@@ -84,7 +84,36 @@ func (p *Parser) Parse(src string) (out Tokens, err error) {
if err != nil {
return out, err
}
- log.Println("Parse in:", in)
+ return p.ParseStmts(in)
+ /*
+ log.Println("Parse in:", in)
+ for len(in) > 0 {
+ endstmt := in.Index(lang.Semicolon)
+ if endstmt == -1 {
+ return out, scanner.ErrBlock
+ }
+ // Skip over simple init statements for some tokens (if, for, ...)
+ if lang.HasInit[in[0].Id] {
+ for in[endstmt-1].Id != lang.BraceBlock {
+ e2 := in[endstmt+1:].Index(lang.Semicolon)
+ if e2 == -1 {
+ return out, scanner.ErrBlock
+ }
+ endstmt += 1 + e2
+ }
+ }
+ o, err := p.ParseStmt(in[:endstmt])
+ if err != nil {
+ return out, err
+ }
+ out = append(out, o...)
+ in = in[endstmt+1:]
+ }
+ return out, err
+ */
+}
+
+func (p *Parser) ParseStmts(in Tokens) (out Tokens, err error) {
for len(in) > 0 {
endstmt := in.Index(lang.Semicolon)
if endstmt == -1 {
@@ -111,7 +140,7 @@ func (p *Parser) Parse(src string) (out Tokens, err error) {
}
func (p *Parser) ParseStmt(in Tokens) (out Tokens, err error) {
- log.Println("ParseStmt in:", in)
+ log.Println("ParseStmt in:", in, len(in))
if len(in) == 0 {
return nil, nil
}
@@ -370,17 +399,16 @@ func (p *Parser) ParseSwitch(in Tokens) (out Tokens, err error) {
}
out = init
}
+ condSwitch := false
if len(cond) > 0 {
if cond, err = p.ParseExpr(cond); err != nil {
return nil, err
}
- } else {
- cond = Tokens{{Id: lang.Ident, Str: "true"}}
+ out = append(out, cond...)
+ condSwitch = true
}
- out = append(out, cond...)
// Split switch body into case clauses.
clauses, err = p.Scan(in[len(in)-1].Block(), true)
- log.Println("## clauses:", clauses)
sc := clauses.SplitStart(lang.Case)
// Make sure that the default clause is the last.
lsc := len(sc) - 1
@@ -391,32 +419,58 @@ func (p *Parser) ParseSwitch(in Tokens) (out Tokens, err error) {
}
}
// Process each clause.
+ nc := len(sc) - 1
for i, cl := range sc {
- co, err := p.ParseCaseClause(cl, i)
+ co, err := p.ParseCaseClause(cl, i, nc, condSwitch)
if err != nil {
return nil, err
}
out = append(out, co...)
}
+ out = append(out, scanner.Token{Id: lang.Label, Str: p.scope + "e"})
return out, err
}
-func (p *Parser) ParseCaseClause(in Tokens, index int) (out Tokens, err error) {
- var initcond, init, cond, body Tokens
+func (p *Parser) ParseCaseClause(in Tokens, index, max int, condSwitch bool) (out Tokens, err error) {
+ in = append(in, scanner.Token{Id: lang.Semicolon}) // Force a ';' at the end of body clause.
+ var conds, body Tokens
tl := in.Split(lang.Colon)
if len(tl) != 2 {
return nil, errors.New("invalid case clause")
}
- initcond, body = tl[0][1:], tl[1]
- if ii := initcond.Index(lang.Semicolon); ii < 0 {
- cond = initcond
- } else {
- init = initcond[:ii]
- cond = initcond[ii+1:]
+ conds = tl[0][1:]
+ if body, err = p.ParseStmts(tl[1]); err != nil {
+ return out, err
+ }
+ lcond := conds.Split(lang.Comma)
+ for i, cond := range lcond {
+ if cond, err = p.ParseExpr(cond); err != nil {
+ return out, err
+ }
+ txt := fmt.Sprintf("%sc%d.%d", p.scope, index, i)
+ next := ""
+ if i == len(lcond)-1 { // End of cond: next, go to next clause or exit
+ if index < max {
+ next = fmt.Sprintf("%sc%d.%d", p.scope, index+1, 0)
+ } else {
+ next = p.scope + "e"
+ }
+ } else {
+ next = fmt.Sprintf("%sc%d.%d", p.scope, index, i+1)
+ }
+ out = append(out, scanner.Token{Id: lang.Label, Str: txt})
+ if len(cond) > 0 {
+ out = append(out, cond...)
+ if condSwitch {
+ out = append(out, scanner.Token{Id: lang.EqualSet})
+ }
+ out = append(out, scanner.Token{Id: lang.JumpFalse, Str: "JumpFalse " + next})
+ }
+ out = append(out, body...)
+ if i != len(lcond)-1 || index != max {
+ out = append(out, scanner.Token{Id: lang.Goto, Str: "Goto " + p.scope + "e"})
+ }
}
- lcond := cond.Split(lang.Comma)
- log.Println("# ParseCaseClause:", init, "cond:", cond, len(lcond))
- _ = body
return out, err
}
diff --git a/vm/vm.go b/vm/vm.go
index 0be9514..6f2df12 100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -22,6 +22,7 @@ const (
Dup // addr -- value ; value = mem[addr]
Fdup // addr -- value ; value = mem[addr]
Equal // n1 n2 -- cond ; cond = n1 == n2
+ EqualSet // n1 n2 -- n1 cond ; cond = n1 == n2
Exit // -- ;
Greater // n1 n2 -- cond; cond = n1 > n2
Jump // -- ; ip += $1
@@ -48,6 +49,7 @@ var strop = [...]string{ // for VM tracing.
CallX: "CallX",
Dup: "Dup",
Equal: "Equal",
+ EqualSet: "EqualSet",
Exit: "Exit",
Fassign: "Fassign",
Fdup: "Fdup",
@@ -141,6 +143,16 @@ func (m *Machine) Run() (err error) {
case Equal:
mem[sp-2] = mem[sp-2].(int) == mem[sp-1].(int)
mem = mem[:sp-1]
+ case EqualSet:
+ if mem[sp-2].(int) == mem[sp-1].(int) {
+ // If equal then lhs and rhs are popped, replaced by test result, as in Equal.
+ mem[sp-2] = true
+ mem = mem[:sp-1]
+ } else {
+ // If not equal then the lhs is let on stack for further processing.
+ // This is used to simplify bytecode in case clauses of switch statments.
+ mem[sp-1] = false
+ }
case Exit:
return err
case Fdup: