Skip to content

Print infix operations in infix form #22854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import java.nio.CharBuffer
import scala.io.Codec
import Int.MaxValue
import Names.*, StdNames.*, Contexts.*, Symbols.*, Flags.*, NameKinds.*, Types.*
import util.Chars.{isOperatorPart, digit2int}
import util.Chars.{isOperatorPart, isIdentifierPart, digit2int}
import Decorators.*
import Definitions.*
import nme.*
Expand Down Expand Up @@ -78,9 +78,22 @@ object NameOps {
def isUnapplyName: Boolean = name == nme.unapply || name == nme.unapplySeq
def isRightAssocOperatorName: Boolean = name.lastPart.last == ':'

def isOperatorName: Boolean = name match
case name: SimpleName => name.exists(isOperatorPart)
case _ => false
/** Does this name match `[{letter | digit} '_'] op`?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*
* See examples in [[NameOpsTest]].
*/
def isOperatorName: Boolean =
name match
case name: SimpleName =>
var i = name.length - 1
// Ends with operator characters
while i >= 0 && isOperatorPart(name(i)) do i -= 1
if i == -1 then return true
// Optionnally prefixed with alpha-numeric characters followed by `_`
if name(i) != '_' then return false
while i >= 0 && isIdentifierPart(name(i)) do i -= 1
i == -1
case _ => false

/** Is name of a variable pattern? */
def isVarPattern: Boolean =
Expand Down
14 changes: 13 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import dotty.tools.dotc.util.SourcePosition
import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef}
import cc.*
import dotty.tools.dotc.parsing.JavaParsers
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp

class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {

Expand Down Expand Up @@ -387,6 +388,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {

def optDotPrefix(tree: This) = optText(tree.qual)(_ ~ ".") provided !isLocalThis(tree)

/** Should a binary operation with this operator be printed infix? */
def isInfix(op: Symbol) =
op.exists && (op.isDeclaredInfix || op.name.isOperatorName)

def caseBlockText(tree: Tree): Text = tree match {
case Block(stats, expr) => toText(stats :+ expr, "\n")
case expr => toText(expr)
Expand Down Expand Up @@ -478,6 +483,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
optDotPrefix(tree) ~ keywordStr("this") ~ idText(tree)
case Super(qual: This, mix) =>
optDotPrefix(qual) ~ keywordStr("super") ~ optText(mix)("[" ~ _ ~ "]")
case BinaryOp(l, op, r) if isInfix(op) =>
val isRightAssoc = op.name.endsWith(":")
val opPrec = parsing.precedence(op.name)
val leftPrec = if isRightAssoc then opPrec + 1 else opPrec
val rightPrec = if !isRightAssoc then opPrec + 1 else opPrec
changePrec(opPrec):
atPrec(leftPrec)(toText(l)) ~ " " ~ toText(op.name) ~ " " ~ atPrec(rightPrec)(toText(r))
case app @ Apply(fun, args) =>
if (fun.hasType && fun.symbol == defn.throwMethod)
changePrec (GlobalPrec) {
Expand All @@ -504,7 +516,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}
}
case Typed(expr, tpt) =>
changePrec(InfixPrec) {
changePrec(DotPrec) {
if isWildcardStarArg(tree) then
expr match
case Ident(nme.WILDCARD_STAR) =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/test-resources/repl/i13181
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
scala> scala.compiletime.codeOf(1+2)
val res0: String = 1.+(2)
val res0: String = 1 + 2
15 changes: 15 additions & 0 deletions compiler/test/dotty/tools/dotc/core/NameOpsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dotty.tools.dotc.core

import dotty.tools.dotc.core.NameOps.isOperatorName
import dotty.tools.dotc.core.Names.{termName, SimpleName}

import org.junit.Test

class NameOpsTest:
@Test def isOperatorNamePos: Unit =
for name <- List("+", "::", "frozen_=:=", "$_+", "a2_+", "a_b_+") do
assert(isOperatorName(termName(name)))

@Test def isOperatorNameNeg: Unit =
for name <- List("foo", "*_*", "<init>", "$reserved", "a*", "2*") do
assert(!isOperatorName(termName(name)))
2 changes: 1 addition & 1 deletion tests/printing/export-param-flags.check
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
package <empty> {
final lazy module val A: A = new A()
final module class A() extends Object() { this: A.type =>
inline def inlinedParam(inline x: Int): Int = x.+(x):Int
inline def inlinedParam(inline x: Int): Int = (x + x):Int
}
final lazy module val Exported: Exported = new Exported()
final module class Exported() extends Object() { this: Exported.type =>
Expand Down
37 changes: 37 additions & 0 deletions tests/printing/infix-operations.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[[syntax trees at end of typer]] // tests/printing/infix-operations.scala
package <empty> {
class C(a: Int) extends Object() {
private[this] val a: Int
def foo(b: Int): C = this
def +(b: Int): C = this
}
final lazy module val infix-operations$package: infix-operations$package =
new infix-operations$package()
final module class infix-operations$package() extends Object() {
this: infix-operations$package.type =>
@main def main: Unit =
{
val v1: Int = 1 + 2 + 3
val v2: Int = 1 + (2 + 3)
val v3: Int = 1 + 2 * 3
val v4: Int = (1 + 2) * 3
val v5: Int = (1 + 2):Int
val v6: Int = (1 + 2):Int
val v7: Boolean = (1 < 2):Boolean
val v8: Boolean = (1 < 2):Boolean
val c: C = new C(2)
val v9: C = c.foo(3)
val v10: C = c + 3
()
}
}
final class main() extends Object() {
<static> def main(args: Array[String]): Unit =
try main catch
{
case error @ _:scala.util.CommandLineParser.ParseError =>
scala.util.CommandLineParser.showError(error)
}
}
}

17 changes: 17 additions & 0 deletions tests/printing/infix-operations.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class C(a: Int):
def foo(b: Int): C = this
def +(b: Int): C = this

@main def main =
val v1 = 1 + 2 + 3
val v2 = 1 + (2 + 3)
val v3 = 1 + 2 * 3
val v4 = (1 + 2) * 3
val v5 = (1 + 2):Int
val v6 = 1 + 2:Int // same as above
val v7 = (1 < 2):Boolean
val v8 = 1 < 2:Boolean // same as above

val c = new C(2) // must not be printed in infix form
val v9 = c.foo(3) // must not be printed in infix form
val v10 = c + 3
24 changes: 12 additions & 12 deletions tests/printing/lambdas.check
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
package <empty> {
final lazy module val Main: Main = new Main()
final module class Main() extends Object() { this: Main.type =>
val f1: Int => Int = (x: Int) => x.+(1)
val f2: (Int, Int) => Int = (x: Int, y: Int) => x.+(y)
val f3: Int => Int => Int = (x: Int) => (y: Int) => x.+(y)
val f4: [T] => (x: Int) => Int = [T >: Nothing <: Any] => (x: Int) => x.+(1)
val f1: Int => Int = (x: Int) => x + 1
val f2: (Int, Int) => Int = (x: Int, y: Int) => x + y
val f3: Int => Int => Int = (x: Int) => (y: Int) => x + y
val f4: [T] => (x: Int) => Int = [T >: Nothing <: Any] => (x: Int) => x + 1
val f5: [T] => (x: Int) => Int => Int = [T >: Nothing <: Any] => (x: Int)
=> (y: Int) => x.+(y)
=> (y: Int) => x + y
val f6: Int => Int = (x: Int) =>
{
val x2: Int = x.+(1)
x2.+(1)
val x2: Int = x + 1
x2 + 1
}
def f7(x: Int): Int = x.+(1)
def f7(x: Int): Int = x + 1
val f8: Int => Int = (x: Int) => Main.f7(x)
val l: List[Int] = List.apply[Int]([1,2,3 : Int]*)
Main.l.map[Int]((_$1: Int) => _$1.+(1))
Main.l.map[Int]((x: Int) => x.+(1))
Main.l.map[Int]((_$1: Int) => _$1 + 1)
Main.l.map[Int]((x: Int) => x + 1)
Main.l.map[Int]((x: Int) =>
{
val x2: Int = x.+(1)
x2.+(1)
val x2: Int = x + 1
x2 + 1
}
)
Main.l.map[Int]((x: Int) => Main.f7(x))
Expand Down
2 changes: 1 addition & 1 deletion tests/printing/posttyper/i22533.check
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package <empty> {
override def equals(x$0: Any): Boolean =
x$0 match
{
case x$0 @ _:Foo @unchecked => this.u.==(x$0.u)
case x$0 @ _:Foo @unchecked => this.u == x$0.u
case _ => false
}
private[this] val u: Int
Expand Down
4 changes: 2 additions & 2 deletions tests/printing/transformed/lazy-vals-legacy.check
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ package <empty> {
{
val flag: Long = scala.runtime.LazyVals.get(this, A.OFFSET$_m_0)
val state: Long = scala.runtime.LazyVals.STATE(flag, 0)
if state.==(3) then return A.x$lzy1 else
if state.==(0) then
if state == 3 then return A.x$lzy1 else
if state == 0 then
if scala.runtime.LazyVals.CAS(this, A.OFFSET$_m_0, flag, 1, 0)
then
try
Expand Down
8 changes: 4 additions & 4 deletions tests/printing/transformed/lazy-vals-new.check
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ package <empty> {
{
val result: Object = A.x$lzy1
if result.isInstanceOf[Int] then Int.unbox(result) else
if result.eq(scala.runtime.LazyVals.NullValue) then Int.unbox(null)
if result eq scala.runtime.LazyVals.NullValue then Int.unbox(null)
else Int.unbox(A.x$lzyINIT1())
}
private def x$lzyINIT1(): Object =
while <empty> do
{
val current: Object = A.x$lzy1
if current.eq(null) then
if current eq null then
if
scala.runtime.LazyVals.objCAS(this, A.OFFSET$_m_0, null,
scala.runtime.LazyVals.Evaluating)
Expand All @@ -42,7 +42,7 @@ package <empty> {
try
{
resultNullable = Int.box(2)
if resultNullable.eq(null) then
if resultNullable eq null then
result = scala.runtime.LazyVals.NullValue else
result = resultNullable
()
Expand All @@ -69,7 +69,7 @@ package <empty> {
current.isInstanceOf[
scala.runtime.LazyVals.LazyVals$LazyValControlState]
then
if current.eq(scala.runtime.LazyVals.Evaluating) then
if current eq scala.runtime.LazyVals.Evaluating then
{
scala.runtime.LazyVals.objCAS(this, A.OFFSET$_m_0, current,
new scala.runtime.LazyVals.LazyVals$Waiting())
Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/lambda-extractor-2.check
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
1.+(2)
1 + 2
{
scala.Predef.println(1)
2
Expand Down
2 changes: 1 addition & 1 deletion tests/run/i13181.scala
Original file line number Diff line number Diff line change
@@ -1 +1 @@
@main def Test = assert(scala.compiletime.codeOf(1+2) == "1.+(2)")
@main def Test = assert(scala.compiletime.codeOf(1+2) == "1 + 2")
2 changes: 1 addition & 1 deletion tests/run/typeCheckErrors.check
Original file line number Diff line number Diff line change
@@ -1 +1 @@
List(Error(value check is not a member of Unit,compileError("1" * 2).check(""),22,Typer), Error(argument to compileError must be a statically known String but was: augmentString("1").*(2),compileError("1" * 2).check(""),13,Typer))
List(Error(value check is not a member of Unit,compileError("1" * 2).check(""),22,Typer), Error(argument to compileError must be a statically known String but was: augmentString("1") * 2,compileError("1" * 2).check(""),13,Typer))