1
1
from __future__ import annotations
2
2
3
+ import sys
3
4
from collections import deque
4
- from typing import Any
5
+ from typing import Any , Iterable
5
6
from unittest import mock
6
7
7
8
from interpreted import nodes
@@ -42,6 +43,7 @@ def __init__(self, parent=None) -> None:
42
43
self .set ("int" , Int ())
43
44
self .set ("float" , Float ())
44
45
self .set ("deque" , DequeConstructor ())
46
+ self .set ("enumerate" , Enumerate ())
45
47
46
48
def get (self , name ) -> Any :
47
49
return self .data .get (name , NOT_SET )
@@ -122,6 +124,20 @@ def call(self, _: Interpreter, args: list[Object]) -> Object:
122
124
raise InterpreterError (f"{ type (item ).__name__ } has no len()" )
123
125
124
126
127
+ class Enumerate (Function ):
128
+ def as_string (self ) -> str :
129
+ return "<function 'enumerate'>"
130
+
131
+ def arg_count (self ) -> int :
132
+ return 1
133
+
134
+ def call (self , _ : Interpreter , args : list [Object ]) -> Object :
135
+ super ().ensure_args (args )
136
+ # We don't have generator support yet :^)
137
+ pairs = [Tuple ([Value (idx ), val ]) for idx , val in enumerate (args [0 ])]
138
+ return List (pairs )
139
+
140
+
125
141
class Int (Function ):
126
142
def as_string (self ) -> str :
127
143
return "<function 'int'>"
@@ -257,6 +273,24 @@ def call(self, _: Interpreter, args: list[Object]) -> None:
257
273
self .wrapper ._data .append (item )
258
274
259
275
276
+ class Items (Function ):
277
+ def __init__ (self , wrapper : Dict ) -> None :
278
+ super ().__init__ ()
279
+ self .wrapper = wrapper
280
+
281
+ def as_string (self ) -> str :
282
+ return f"<method 'items' of { self .wrapper .repr ()} >"
283
+
284
+ def arg_count (self ) -> int :
285
+ return 0
286
+
287
+ def call (self , _ : Interpreter , args : list [Object ]) -> Any :
288
+ super ().ensure_args (args )
289
+ # We don't have generator support yet :^)
290
+ pairs = [Tuple (key_value_pair ) for key_value_pair in self .wrapper ._dict .items ()]
291
+ return List (pairs )
292
+
293
+
260
294
class PopLeft (Function ):
261
295
def __init__ (self , deque : Deque ) -> None :
262
296
super ().__init__ ()
@@ -354,29 +388,35 @@ def call(self, _: Interpreter, args: list[Object]) -> Value:
354
388
355
389
356
390
class List (Object ):
357
- def __init__ (self , elements ) -> None :
391
+ def __init__ (self , elements : Iterable [ Object ] ) -> None :
358
392
super ().__init__ ()
359
393
self ._data = elements
360
394
self .methods ["append" ] = Append (self )
361
395
362
396
def as_string (self ) -> str :
363
397
return "[" + ", " .join (item .repr () for item in self ._data ) + "]"
364
398
399
+ def __iter__ (self ) -> Iterable [Object ]:
400
+ return iter (self ._data )
401
+
365
402
366
403
class Tuple (Object ):
367
- def __init__ (self , elements ) -> None :
404
+ def __init__ (self , elements : Iterable [ Object ] ) -> None :
368
405
super ().__init__ ()
369
406
self ._data = elements
370
407
371
408
def as_string (self ) -> str :
372
409
return "(" + ", " .join (item .repr () for item in self ._data ) + ")"
373
410
411
+ def __iter__ (self ) -> Iterable [Object ]:
412
+ return iter (self ._data )
413
+
374
414
375
415
class Dict (Object ):
376
416
def __init__ (self , keys : list [Object ], values : list [Object ]) -> None :
377
417
super ().__init__ ()
378
-
379
- self ._dict = { key : value for key , value in zip ( keys , values , strict = True )}
418
+ self . _dict = { key : value for key , value in zip ( keys , values )}
419
+ self .methods [ "items" ] = Items ( self )
380
420
381
421
def as_string (self ) -> str :
382
422
return (
@@ -387,6 +427,9 @@ def as_string(self) -> str:
387
427
+ "}"
388
428
)
389
429
430
+ def __iter__ (self ) -> Iterable [Object ]:
431
+ return iter (list (self ._dict ))
432
+
390
433
391
434
def is_truthy (obj : Object ) -> bool :
392
435
if isinstance (obj , Value ):
@@ -486,14 +529,16 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
486
529
487
530
self .scope .set (node .name , function )
488
531
489
- def visit_Assign (self , node : Assign ) -> None :
490
- value = self .visit (node .value )
491
- assert len (node .targets ) == 1 # TODO
492
- target = node .targets [0 ]
493
-
532
+ def assign (self , target : Node , value : Object ) -> None :
494
533
if isinstance (target , Name ):
495
534
self .scope .set (target .id , value )
496
535
536
+ elif isinstance (target , (nodes .List , nodes .Tuple )) and isinstance (
537
+ value , (List , Tuple , Deque , Dict )
538
+ ):
539
+ for element , value in zip (target .elements , value ):
540
+ self .assign (element , value )
541
+
497
542
elif isinstance (target , Subscript ):
498
543
obj = self .visit (target .value )
499
544
@@ -517,7 +562,14 @@ def visit_Assign(self, node: Assign) -> None:
517
562
)
518
563
519
564
else :
520
- raise NotImplementedError (target ) # TODO
565
+ raise NotImplementedError (target , value ) # TODO
566
+
567
+ def visit_Assign (self , node : Assign ) -> None :
568
+ value = self .visit (node .value )
569
+ assert len (node .targets ) == 1 # TODO
570
+ target = node .targets [0 ]
571
+
572
+ self .assign (target , value )
521
573
522
574
def visit_AugAssign (self , node : AugAssign ) -> None :
523
575
increment = self .visit (node .value )
@@ -544,6 +596,29 @@ def visit_If(self, node: If) -> None:
544
596
for stmt in node .orelse :
545
597
self .visit (stmt )
546
598
599
+ def visit_For (self , node : nodes .For ) -> None :
600
+ if isinstance (node .iterable , (nodes .List , nodes .Tuple )):
601
+ elements = [self .visit (element ) for element in node .iterable .elements ]
602
+ elif isinstance (node .iterable , nodes .Dict ):
603
+ elements = [self .visit (element ) for element in node .iterable .keys ]
604
+ else :
605
+ elements = self .visit (node .iterable )
606
+ if not isinstance (elements , (List , Tuple , Deque , Dict )):
607
+ raise InterpreterError (
608
+ f"Object of type { type (elements ).__name__ } is not iterable"
609
+ )
610
+
611
+ for element in elements :
612
+ self .assign (node .target , element )
613
+
614
+ for stmt in node .body :
615
+ try :
616
+ self .visit (stmt )
617
+ except Break :
618
+ return
619
+ except Continue :
620
+ break
621
+
547
622
def visit_While (self , node : While ) -> None :
548
623
while is_truthy (self .visit (node .condition )):
549
624
for stmt in node .body :
@@ -792,3 +867,24 @@ def interpret(source: str) -> None:
792
867
return
793
868
794
869
Interpreter ().visit (module )
870
+
871
+
872
+ def main () -> None :
873
+ source = sys .stdin .read ()
874
+ module = interpret (source )
875
+ if module is None :
876
+ return
877
+
878
+ if "--pretty" in sys .argv :
879
+ try :
880
+ import black
881
+ except ImportError :
882
+ print ("Error: `black` needs to be installed for `--pretty` to work." )
883
+
884
+ print (black .format_str (repr (module ), mode = black .Mode ()))
885
+ else :
886
+ print (module )
887
+
888
+
889
+ if __name__ == "__main__" :
890
+ main ()
0 commit comments