Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ private fun MutableList<Backend.OutputFileSpecification>.addLib(
path = filePath("src", "lib.rs"),
content = Rust.SourceFile(
pos,
attrs = listOf(allowWarnings(pos)),
attrs = allowWarnings(pos),
items = buildList {
// Separate mods.
declareSubmods(pos, allModKids[libraryConfiguration.libraryRoot] ?: setOf())
Expand Down
14 changes: 6 additions & 8 deletions be-rust/src/commonMain/kotlin/lang/temper/be/rust/RustExt.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ import lang.temper.type2.withType
* change in edition. We should investigate this more in the future, though. It would
* be good to support latest expectations.
*/
internal fun allowWarnings(pos: Position): Rust.AttrInner = Rust.AttrInner(
pos,
Rust.Call(
pos,
"allow".toId(pos),
listOf("dependency_on_unit_never_type_fallback", "warnings").map { it.toId(pos) },
),
)
internal fun allowWarnings(pos: Position) = run {
// Separate them with `warnings` first, to prevent warnings against the other on older rust.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least, I seemed to see effects per the comment above, but I didn't spend a lot of time on this aspect.

listOf("warnings", "dependency_on_unit_never_type_fallback").map { key ->
Rust.AttrInner(pos, Rust.Call(pos, "allow".toId(pos), listOf(key.toId(pos))))
}
}

internal fun makeError(pos: Position) = Rust.Call(pos, callee = ERROR_NEW_NAME.toId(pos), args = listOf())

Expand Down
145 changes: 99 additions & 46 deletions be-rust/src/commonMain/kotlin/lang/temper/be/rust/RustTranslator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ import lang.temper.be.tmpl.typeOrInvalid
import lang.temper.common.compatRemoveLast
import lang.temper.common.subListToEnd
import lang.temper.frontend.ModuleNamingContext
import lang.temper.frontend.typestage.findOverrides
import lang.temper.interp.importExport.STANDARD_LIBRARY_NAME
import lang.temper.lexer.withTemperAwareExtension
import lang.temper.library.LibraryConfigurations
import lang.temper.log.FilePath
import lang.temper.log.LogSink
import lang.temper.log.Position
import lang.temper.log.last
import lang.temper.name.BuiltinName
Expand All @@ -38,6 +40,7 @@ import lang.temper.name.Temporary
import lang.temper.type.Abstractness
import lang.temper.type.MethodKind
import lang.temper.type.MethodShape
import lang.temper.type.PropertyShape
import lang.temper.type.TypeDefinition
import lang.temper.type.TypeFormal
import lang.temper.type.TypeShape
Expand All @@ -51,6 +54,7 @@ import lang.temper.type2.NonNullType
import lang.temper.type2.Nullity
import lang.temper.type2.Signature2
import lang.temper.type2.Type2
import lang.temper.type2.TypeContext2
import lang.temper.type2.TypeParamRef
import lang.temper.type2.ValueFormalKind
import lang.temper.type2.hackMapOldStyleToNew
Expand Down Expand Up @@ -94,6 +98,7 @@ class RustTranslator(
private var insideMutableType = false
private val failVars = mutableSetOf<ResolvedName>()
private val functionContextStack = mutableListOf<FunctionContext>()
private val logSink = LogSink.devNull // TODO what?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need some log sink, and this works for now, but maybe should reconsider or even just pass log sinks into translators. I've wanted that more than once before.

private val loopLabels = mutableListOf<Rust.Id?>()
private val moduleInits = mutableListOf<Rust.Statement>()
private val moduleItems = mutableListOf<Rust.Item>()
Expand All @@ -102,6 +107,7 @@ class RustTranslator(
}
private val testItems = mutableListOf<Rust.Item>()
private val traitImports = mutableSetOf<Rust.Path>()
private val typeContext = TypeContext2()

fun translateModule(): Backend.TranslatedFileSpecification {
// Preprocess tops.
Expand Down Expand Up @@ -137,7 +143,7 @@ class RustTranslator(
path = makeSrcFilePath(relPath.withTemperAwareExtension("")),
content = Rust.SourceFile(
pos,
attrs = listOf(allowWarnings(module.pos)),
attrs = allowWarnings(module.pos),
items = buildList {
// Declare submodules, except for root that needs to declare in lib file.
if (!isRoot) {
Expand Down Expand Up @@ -310,7 +316,7 @@ class RustTranslator(
// And for now, skip those with rest parameters. TODO Extract to list value?
fn.parameters.restParameter != null && return
// Build the builder.
// Here we make `WhateverBuilder` for requireds and/or `WhateverBuilderOptions` structs for optionals.
// Here we make `WhateverBuilder` for requireds and/or `WhateverOptions` structs for optionals.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed in passing that this comment was wrong vs current behavior.

// Alternatively, could make a Java-style builder, which is common in Rust, but it takes less advantage of
// standard static checking that we get with pub struct fields.
val pos = fn.pos
Expand Down Expand Up @@ -775,8 +781,8 @@ class RustTranslator(
for (supMethod in supShape.methods) {
maybeAddTraitForwarder(
pos,
instanceMethods,
isInterface = isInterface,
decl = decl,
instanceMethods = instanceMethods,
methodKind = supMethod.methodKind,
methodName = supMethod.name.displayName,
returnType = supMethod.descriptor.orInvalid.returnType2,
Expand All @@ -788,8 +794,8 @@ class RustTranslator(
if (supProperty.getter == null) {
maybeAddTraitForwarder(
pos,
instanceMethods,
isInterface = isInterface,
decl = decl,
instanceMethods = instanceMethods,
methodKind = MethodKind.Getter,
methodName = BuiltinName("get.${supProperty.symbol.text}").displayName,
superShape = supProperty,
Expand All @@ -798,8 +804,8 @@ class RustTranslator(
if (supProperty.setter == null && supProperty.hasSetter) {
maybeAddTraitForwarder(
pos,
instanceMethods,
isInterface = isInterface,
decl = decl,
instanceMethods = instanceMethods,
methodKind = MethodKind.Setter,
methodName = BuiltinName("set.${supProperty.symbol.text}").displayName,
superShape = supProperty,
Expand All @@ -815,18 +821,19 @@ class RustTranslator(

private fun MutableList<Rust.Item>.maybeAddTraitForwarder(
pos: Position,
decl: TmpL.TypeDeclaration,
instanceMethods: Map<String, TmpL.InstanceMethod>,
isInterface: Boolean,
methodKind: MethodKind,
methodName: String,
superShape: VisibleMemberShape,
returnType: Type2? = null,
) {
val isInterface = decl.kind == TmpL.TypeDeclarationKind.Interface
when {
isInterface -> buildForwarderForTrait(pos, superShape, methodKind)
isInterface -> buildForwarderFromInterfaceToTrait(pos, superShape, methodKind)
else -> when (val method = instanceMethods[methodName]) {
null -> listOf()
else -> buildForwarder(method, returnType = returnType)
null -> buildForwarderFromClassToTrait(pos, decl, superShape, methodKind)
else -> buildForwarder(method, returnType)
}
}.also { addAll(it) } // only one expected here, but meh
}
Expand Down Expand Up @@ -1073,59 +1080,112 @@ class RustTranslator(
return translateMethodLike(method, block = block, forTrait = true, returnType = effectiveReturnType)
}

private fun buildForwarderFromClassToTrait(
pos: Position,
decl: TmpL.TypeDeclaration,
superShape: VisibleMemberShape,
methodKind: MethodKind,
): List<Rust.Item> = run {
// We're here because this class has no matching member, so walk its supertypes to match the super method.
// The method we inherit closest might be on a different branch.
val overrides = findOverrides(decl.typeShape, superShape, typeContext, logSink)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I factored out findOverrides from existing logic. I think it meets the needs here.

val override = overrides.find overrides@{ override ->
when (val foundMember = override.superTypeMember) {
is MethodShape -> !foundMember.isPureVirtual
is PropertyShape -> when (methodKind) {
MethodKind.Getter -> foundMember.getter
MethodKind.Setter -> foundMember.setter
else -> return@overrides false
}.let { foundName ->
// Interfaces can only provide property implementations with methods, so look at those.
foundMember.enclosingType.methods.any { method ->
!method.isPureVirtual && method.methodKind == methodKind && method.name == foundName
}
}
else -> false
}
}
// Having selected an override, forward to it with simple self.
val targetType = override?.superTypeMember?.enclosingType
if (targetType == superShape.enclosingType) {
// Just let the trait handle this one directly. Self-call here is infinite recursion.
return listOf()
}
buildForwarderToTrait(pos, targetType, superShape, methodKind) result@{ traitType, methodId, argIds ->
Rust.Call(
pos,
callee = traitType.extendWith(methodId.deepCopy()),
args = buildList {
add("self".toKeyId(pos))
addAll(argIds)
},
)
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now factored to share logic with trait-wrapper-to-trait forwarding that I put in on the previous pr. The main difference is what happens inside the forwarding method, so that's what the callback is for.

}

/**
* Build a forwarder from a trait wrapper to a trait method that *isn't*
* overridden in the current trait. We need this to handle methods for
* which we have only frontend descriptions, not tmpl.
*/
private fun buildForwarderForTrait(
private fun buildForwarderFromInterfaceToTrait(
pos: Position,
shape: VisibleMemberShape,
methodKind: MethodKind,
): List<Rust.Item> = run {
// From trait wrapper to trait, just forward the call with the unwrapped innards.
buildForwarderToTrait(pos, shape.enclosingType, shape, methodKind) result@{ traitType, methodId, argIds ->
Rust.Call(
pos,
callee = traitType.extendWith(methodId.deepCopy()),
args = buildList {
add("self".toKeyId(pos).member("0", notMethod = true).deref().ref())
addAll(argIds)
},
)
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This retains behavior from before while factoring most of the logic to share with the new class-to-trait forwarding above.

}

private fun buildForwarderToTrait(
pos: Position,
targetType: TypeShape?,
shape: VisibleMemberShape,
methodKind: MethodKind,
/** Trait type, method id, and arg ids, all ready to be used. */
buildResult: (Rust.Path, Rust.Id, List<Rust.Id>) -> Rust.Expr?,
): List<Rust.Item> = run {
val selfParam = Rust.RefType(pos, "self".toKeyId(pos))
val selfArg = "self".toKeyId(pos).member("0", notMethod = true).deref().ref()
val enclosingType =
(translateTypeDefinition(shape.enclosingType, pos) as? Rust.Path)?.suffixed(TRAIT_NAME_SUFFIX)
val traitType = targetType?.let {
(translateTypeDefinition(targetType, pos) as? Rust.Path)?.suffixed(TRAIT_NAME_SUFFIX)
}
when (methodKind) {
MethodKind.Normal -> {
val method = shape as MethodShape
val methodId = translateIdFromName(pos, method.name as ResolvedName, NameStyle.Snake)
val sig = method.descriptor ?: return listOf()
val argNames = (1..<sig.requiredInputTypes.size + sig.optionalInputTypes.size).map { "arg$it" }
// We don't have param names here, so invent some.
val argIds = (1..<sig.requiredInputTypes.size + sig.optionalInputTypes.size).map { arg ->
"arg$arg".toId(pos)
}
Rust.Function(
pos,
id = methodId,
id = methodId.deepCopy(),
params = buildList {
add(selfParam)
var index = 0
for (paramType in sig.requiredInputTypes.subListToEnd(1)) {
val paramName = argNames[index++].toId(pos)
val paramName = argIds[index++].deepCopy()
val translatedType = translateType(paramType, pos)
add(Rust.FunctionParam(pos, paramName, translatedType))
}
for (paramType in sig.optionalInputTypes) {
val paramName = argNames[index++].toId(pos)
val paramName = argIds[index++].deepCopy()
val translatedType = translateType(paramType, pos).option()
add(Rust.FunctionParam(pos, paramName, translatedType))
}
},
returnType = method.descriptor?.let { translateType(it.returnType2, pos = pos) },
block = Rust.Block(
pos,
result = enclosingType?.let { type ->
Rust.Call(
pos,
callee = type.extendWith(methodId.deepCopy()),
args = buildList {
add(selfArg)
for (argName in argNames) {
add(argName.toId(pos))
}
},
)
},
),
block = Rust.Block(pos, result = traitType?.let { buildResult(it, methodId, argIds) }),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actual method content is now handled per above callbacks.

)
}
MethodKind.Getter -> {
Expand All @@ -1135,15 +1195,12 @@ class RustTranslator(
is Type2 -> descriptor
else -> null
}?.let { translateType(it, pos = pos) }
val call = enclosingType?.let { type ->
Rust.Call(pos, type.extendWith(methodId.deepCopy()), listOf(selfArg))
}
Rust.Function(
pos,
id = methodId,
id = methodId.deepCopy(),
params = listOf(selfParam),
returnType = returnType,
block = Rust.Block(pos, result = call),
block = Rust.Block(pos, result = traitType?.let { buildResult(it, methodId, listOf()) }),
)
}
MethodKind.Setter -> {
Expand All @@ -1155,15 +1212,11 @@ class RustTranslator(
}?.let { translateType(it, pos = pos) }
// We don't have param names here, so invent one.
val value = "value".toId(pos)
val call = enclosingType?.let { type ->
val args = listOf(selfArg, value.deepCopy())
Rust.Call(pos, type.extendWith(methodId.deepCopy()), args)
}
Rust.Function(
pos,
id = methodId,
params = listOf(selfParam, Rust.FunctionParam(pos, value, propertyType)),
block = Rust.Block(pos, result = call),
params = listOf(selfParam, Rust.FunctionParam(pos, value.deepCopy(), propertyType)),
block = Rust.Block(pos, result = traitType?.let { buildResult(it, methodId, listOf(value)) }),
)
}
else -> return listOf()
Expand Down
Loading
Loading