wasm gc: fix issues with type inference

This commit is contained in:
Alexey Andreev 2024-09-04 20:58:29 +02:00
parent b36f38f48f
commit 29dec0962b
4 changed files with 38 additions and 20 deletions

View File

@ -22,20 +22,15 @@ import org.teavm.model.util.VariableCategoryProvider;
public class WasmGCVariableCategoryProvider implements VariableCategoryProvider {
private ClassHierarchy hierarchy;
private PreciseTypeInference inference;
public WasmGCVariableCategoryProvider(ClassHierarchy hierarchy) {
this.hierarchy = hierarchy;
}
public PreciseTypeInference getTypeInference() {
return inference;
}
@Override
public Object[] getCategories(Program program, MethodReference method) {
inference = new PreciseTypeInference(program, method, hierarchy);
inference.setPhisSkipped(true);
var inference = new PreciseTypeInference(program, method, hierarchy);
inference.setPhisSkipped(false);
inference.setBackPropagation(true);
var result = new Object[program.variableCount()];
for (int i = 0; i < program.variableCount(); ++i) {

View File

@ -25,6 +25,7 @@ import java.util.Set;
import org.teavm.ast.decompilation.Decompiler;
import org.teavm.backend.wasm.BaseWasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.gc.PreciseTypeInference;
import org.teavm.backend.wasm.gc.WasmGCVariableCategoryProvider;
import org.teavm.backend.wasm.gc.vtable.WasmGCVirtualTableProvider;
import org.teavm.backend.wasm.generate.gc.WasmGCNameProvider;
@ -220,7 +221,10 @@ public class WasmGCMethodGenerator implements BaseWasmFunctionRepository {
allocator.allocateRegisters(method.getReference(), method.getProgram(), friendlyToDebugger);
var ast = decompiler.decompileRegular(method);
var firstVar = method.hasModifier(ElementModifier.STATIC) ? 1 : 0;
var typeInference = categoryProvider.getTypeInference();
var typeInference = new PreciseTypeInference(method.getProgram(), method.getReference(), hierarchy);
typeInference.setPhisSkipped(true);
typeInference.setBackPropagation(true);
typeInference.ensure();
var registerCount = 0;
for (var i = 0; i < method.getProgram().variableCount(); ++i) {

View File

@ -547,6 +547,11 @@ public abstract class BaseTypeInference<T> {
push(insn.getValue(), insn.getFieldType());
}
@Override
public void visit(CastInstruction insn) {
push(insn.getValue(), insn.getTargetType());
}
private void push(Variable variable, ValueType type) {
if (nullTypes[variable.getIndex()]) {
stack.push(variable.getIndex());

View File

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.teavm.common.DisjointSet;
import org.teavm.common.MutableGraphEdge;
import org.teavm.common.MutableGraphNode;
@ -45,7 +46,10 @@ public class RegisterAllocator {
}
public void allocateRegisters(MethodReference method, Program program, boolean debuggerFriendly) {
insertPhiArgumentsCopies(program);
var categories = variableCategoryProvider.getCategories(program, method);
var categoryList = new ArrayList<>(Arrays.asList(categories));
insertPhiArgumentsCopies(program, categoryList);
categories = categoryList.toArray();
InterferenceGraphBuilder interferenceBuilder = new InterferenceGraphBuilder();
LivenessAnalyzer liveness = new LivenessAnalyzer();
liveness.analyze(program, method.getDescriptor());
@ -53,7 +57,7 @@ public class RegisterAllocator {
program, method.parameterCount(), liveness);
DisjointSet congruenceClasses = buildPhiCongruenceClasses(program);
joinClassNodes(interferenceGraph, congruenceClasses);
removeRedundantCopies(program, interferenceGraph, congruenceClasses);
removeRedundantCopies(program, interferenceGraph, congruenceClasses, categories);
int[] classArray = congruenceClasses.pack(program.variableCount());
renameVariables(program, classArray);
int[] colors = new int[program.variableCount()];
@ -68,8 +72,13 @@ public class RegisterAllocator {
for (int cls : classArray) {
maxClass = Math.max(maxClass, cls + 1);
}
var categories = variableCategoryProvider.getCategories(program, method);
String[] names = getVariableNames(program, debuggerFriendly);
var newCategories = new Object[categories.length];
for (int i = 0; i < categories.length; ++i) {
var cls = classArray[i];
newCategories[cls] = categories[i];
}
categories = newCategories;
colorer.colorize(MutableGraphNode.toGraph(interferenceGraph), colors, categories, names);
int maxColor = 0;
@ -129,7 +138,7 @@ public class RegisterAllocator {
}
}
private void insertPhiArgumentsCopies(Program program) {
private void insertPhiArgumentsCopies(Program program, List<Object> categories) {
List<List<Incoming>> catchIncomingsByVariable = new ArrayList<>(
Collections.nCopies(program.variableCount(), null));
@ -151,14 +160,14 @@ public class RegisterAllocator {
}
catchIncomings.add(incoming);
} else {
insertCopy(incoming, blockMap);
insertCopy(incoming, blockMap, categories);
incomingsToRepeat.add(incoming);
}
}
}
for (Incoming incoming : incomingsToRepeat) {
insertCopy(incoming, blockMap);
insertCopy(incoming, blockMap, categories);
}
}
@ -167,7 +176,7 @@ public class RegisterAllocator {
for (BasicBlock block : program.getBasicBlocks()) {
for (Phi phi : block.getPhis()) {
addExceptionHandlingCopies(catchIncomingsByVariable, phi.getReceiver(), block,
program, block.getFirstInstruction().getLocation(), nextInstructions);
program, block.getFirstInstruction().getLocation(), nextInstructions, categories);
}
if (!nextInstructions.isEmpty()) {
@ -180,7 +189,7 @@ public class RegisterAllocator {
Variable[] definedVariables = definitionExtractor.getDefinedVariables();
for (Variable definedVariable : definedVariables) {
addExceptionHandlingCopies(catchIncomingsByVariable, definedVariable, block,
program, instruction.getLocation(), nextInstructions);
program, instruction.getLocation(), nextInstructions, categories);
}
if (!nextInstructions.isEmpty()) {
@ -198,6 +207,7 @@ public class RegisterAllocator {
BasicBlock block = incoming.getSource();
Variable copy = program.createVariable();
categories.add(categories.get(incoming.getPhi().getReceiver().getIndex()));
copy.setLabel(incoming.getPhi().getReceiver().getLabel());
copy.setDebugName(incoming.getPhi().getReceiver().getDebugName());
@ -213,7 +223,8 @@ public class RegisterAllocator {
}
private void addExceptionHandlingCopies(List<List<Incoming>> catchIncomingsByVariable, Variable definedVariable,
BasicBlock block, Program program, TextLocation location, List<Instruction> nextInstructions) {
BasicBlock block, Program program, TextLocation location, List<Instruction> nextInstructions,
List<Object> categories) {
if (definedVariable.getIndex() >= catchIncomingsByVariable.size()) {
return;
}
@ -228,6 +239,7 @@ public class RegisterAllocator {
Variable copy = program.createVariable();
copy.setLabel(incoming.getPhi().getReceiver().getLabel());
copy.setDebugName(incoming.getPhi().getReceiver().getDebugName());
categories.add(categories.get(incoming.getPhi().getReceiver().getIndex()));
AssignInstruction copyInstruction = new AssignInstruction();
copyInstruction.setReceiver(copy);
@ -242,11 +254,12 @@ public class RegisterAllocator {
}
}
private void insertCopy(Incoming incoming, Map<BasicBlock, BasicBlock> blockMap) {
private void insertCopy(Incoming incoming, Map<BasicBlock, BasicBlock> blockMap, List<Object> categories) {
Phi phi = incoming.getPhi();
Program program = phi.getBasicBlock().getProgram();
AssignInstruction copyInstruction = new AssignInstruction();
Variable firstCopy = program.createVariable();
categories.add(categories.get(phi.getReceiver().getIndex()));
firstCopy.setLabel(phi.getReceiver().getLabel());
firstCopy.setDebugName(phi.getReceiver().getDebugName());
copyInstruction.setReceiver(firstCopy);
@ -274,7 +287,7 @@ public class RegisterAllocator {
}
private void removeRedundantCopies(Program program, List<MutableGraphNode> interferenceGraph,
DisjointSet congruenceClasses) {
DisjointSet congruenceClasses, Object[] categories) {
for (int i = 0; i < program.basicBlockCount(); ++i) {
BasicBlock block = program.basicBlockAt(i);
Instruction nextInsn;
@ -297,7 +310,8 @@ public class RegisterAllocator {
break;
}
}
if (!interfere) {
if (!interfere && Objects.equals(categories[assignment.getReceiver().getIndex()],
categories[assignment.getAssignee().getIndex()])) {
int newClass = congruenceClasses.union(copyClass, origClass);
insn.delete();
if (newClass == interferenceGraph.size()) {