wasm gc: fix issues with type inference

This commit is contained in:
Alexey Andreev 2024-09-04 20:58:29 +02:00
parent b36f38f48f
commit 29dec0962b
4 changed files with 38 additions and 20 deletions

View File

@ -22,20 +22,15 @@ import org.teavm.model.util.VariableCategoryProvider;
public class WasmGCVariableCategoryProvider implements VariableCategoryProvider { public class WasmGCVariableCategoryProvider implements VariableCategoryProvider {
private ClassHierarchy hierarchy; private ClassHierarchy hierarchy;
private PreciseTypeInference inference;
public WasmGCVariableCategoryProvider(ClassHierarchy hierarchy) { public WasmGCVariableCategoryProvider(ClassHierarchy hierarchy) {
this.hierarchy = hierarchy; this.hierarchy = hierarchy;
} }
public PreciseTypeInference getTypeInference() {
return inference;
}
@Override @Override
public Object[] getCategories(Program program, MethodReference method) { public Object[] getCategories(Program program, MethodReference method) {
inference = new PreciseTypeInference(program, method, hierarchy); var inference = new PreciseTypeInference(program, method, hierarchy);
inference.setPhisSkipped(true); inference.setPhisSkipped(false);
inference.setBackPropagation(true); inference.setBackPropagation(true);
var result = new Object[program.variableCount()]; var result = new Object[program.variableCount()];
for (int i = 0; i < program.variableCount(); ++i) { for (int i = 0; i < program.variableCount(); ++i) {

View File

@ -25,6 +25,7 @@ import java.util.Set;
import org.teavm.ast.decompilation.Decompiler; import org.teavm.ast.decompilation.Decompiler;
import org.teavm.backend.wasm.BaseWasmFunctionRepository; import org.teavm.backend.wasm.BaseWasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes; import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.gc.PreciseTypeInference;
import org.teavm.backend.wasm.gc.WasmGCVariableCategoryProvider; import org.teavm.backend.wasm.gc.WasmGCVariableCategoryProvider;
import org.teavm.backend.wasm.gc.vtable.WasmGCVirtualTableProvider; import org.teavm.backend.wasm.gc.vtable.WasmGCVirtualTableProvider;
import org.teavm.backend.wasm.generate.gc.WasmGCNameProvider; import org.teavm.backend.wasm.generate.gc.WasmGCNameProvider;
@ -220,7 +221,10 @@ public class WasmGCMethodGenerator implements BaseWasmFunctionRepository {
allocator.allocateRegisters(method.getReference(), method.getProgram(), friendlyToDebugger); allocator.allocateRegisters(method.getReference(), method.getProgram(), friendlyToDebugger);
var ast = decompiler.decompileRegular(method); var ast = decompiler.decompileRegular(method);
var firstVar = method.hasModifier(ElementModifier.STATIC) ? 1 : 0; var firstVar = method.hasModifier(ElementModifier.STATIC) ? 1 : 0;
var typeInference = categoryProvider.getTypeInference(); var typeInference = new PreciseTypeInference(method.getProgram(), method.getReference(), hierarchy);
typeInference.setPhisSkipped(true);
typeInference.setBackPropagation(true);
typeInference.ensure();
var registerCount = 0; var registerCount = 0;
for (var i = 0; i < method.getProgram().variableCount(); ++i) { for (var i = 0; i < method.getProgram().variableCount(); ++i) {

View File

@ -547,6 +547,11 @@ public abstract class BaseTypeInference<T> {
push(insn.getValue(), insn.getFieldType()); push(insn.getValue(), insn.getFieldType());
} }
@Override
public void visit(CastInstruction insn) {
push(insn.getValue(), insn.getTargetType());
}
private void push(Variable variable, ValueType type) { private void push(Variable variable, ValueType type) {
if (nullTypes[variable.getIndex()]) { if (nullTypes[variable.getIndex()]) {
stack.push(variable.getIndex()); stack.push(variable.getIndex());

View File

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import org.teavm.common.DisjointSet; import org.teavm.common.DisjointSet;
import org.teavm.common.MutableGraphEdge; import org.teavm.common.MutableGraphEdge;
import org.teavm.common.MutableGraphNode; import org.teavm.common.MutableGraphNode;
@ -45,7 +46,10 @@ public class RegisterAllocator {
} }
public void allocateRegisters(MethodReference method, Program program, boolean debuggerFriendly) { public void allocateRegisters(MethodReference method, Program program, boolean debuggerFriendly) {
insertPhiArgumentsCopies(program); var categories = variableCategoryProvider.getCategories(program, method);
var categoryList = new ArrayList<>(Arrays.asList(categories));
insertPhiArgumentsCopies(program, categoryList);
categories = categoryList.toArray();
InterferenceGraphBuilder interferenceBuilder = new InterferenceGraphBuilder(); InterferenceGraphBuilder interferenceBuilder = new InterferenceGraphBuilder();
LivenessAnalyzer liveness = new LivenessAnalyzer(); LivenessAnalyzer liveness = new LivenessAnalyzer();
liveness.analyze(program, method.getDescriptor()); liveness.analyze(program, method.getDescriptor());
@ -53,7 +57,7 @@ public class RegisterAllocator {
program, method.parameterCount(), liveness); program, method.parameterCount(), liveness);
DisjointSet congruenceClasses = buildPhiCongruenceClasses(program); DisjointSet congruenceClasses = buildPhiCongruenceClasses(program);
joinClassNodes(interferenceGraph, congruenceClasses); joinClassNodes(interferenceGraph, congruenceClasses);
removeRedundantCopies(program, interferenceGraph, congruenceClasses); removeRedundantCopies(program, interferenceGraph, congruenceClasses, categories);
int[] classArray = congruenceClasses.pack(program.variableCount()); int[] classArray = congruenceClasses.pack(program.variableCount());
renameVariables(program, classArray); renameVariables(program, classArray);
int[] colors = new int[program.variableCount()]; int[] colors = new int[program.variableCount()];
@ -68,8 +72,13 @@ public class RegisterAllocator {
for (int cls : classArray) { for (int cls : classArray) {
maxClass = Math.max(maxClass, cls + 1); maxClass = Math.max(maxClass, cls + 1);
} }
var categories = variableCategoryProvider.getCategories(program, method);
String[] names = getVariableNames(program, debuggerFriendly); String[] names = getVariableNames(program, debuggerFriendly);
var newCategories = new Object[categories.length];
for (int i = 0; i < categories.length; ++i) {
var cls = classArray[i];
newCategories[cls] = categories[i];
}
categories = newCategories;
colorer.colorize(MutableGraphNode.toGraph(interferenceGraph), colors, categories, names); colorer.colorize(MutableGraphNode.toGraph(interferenceGraph), colors, categories, names);
int maxColor = 0; int maxColor = 0;
@ -129,7 +138,7 @@ public class RegisterAllocator {
} }
} }
private void insertPhiArgumentsCopies(Program program) { private void insertPhiArgumentsCopies(Program program, List<Object> categories) {
List<List<Incoming>> catchIncomingsByVariable = new ArrayList<>( List<List<Incoming>> catchIncomingsByVariable = new ArrayList<>(
Collections.nCopies(program.variableCount(), null)); Collections.nCopies(program.variableCount(), null));
@ -151,14 +160,14 @@ public class RegisterAllocator {
} }
catchIncomings.add(incoming); catchIncomings.add(incoming);
} else { } else {
insertCopy(incoming, blockMap); insertCopy(incoming, blockMap, categories);
incomingsToRepeat.add(incoming); incomingsToRepeat.add(incoming);
} }
} }
} }
for (Incoming incoming : incomingsToRepeat) { for (Incoming incoming : incomingsToRepeat) {
insertCopy(incoming, blockMap); insertCopy(incoming, blockMap, categories);
} }
} }
@ -167,7 +176,7 @@ public class RegisterAllocator {
for (BasicBlock block : program.getBasicBlocks()) { for (BasicBlock block : program.getBasicBlocks()) {
for (Phi phi : block.getPhis()) { for (Phi phi : block.getPhis()) {
addExceptionHandlingCopies(catchIncomingsByVariable, phi.getReceiver(), block, addExceptionHandlingCopies(catchIncomingsByVariable, phi.getReceiver(), block,
program, block.getFirstInstruction().getLocation(), nextInstructions); program, block.getFirstInstruction().getLocation(), nextInstructions, categories);
} }
if (!nextInstructions.isEmpty()) { if (!nextInstructions.isEmpty()) {
@ -180,7 +189,7 @@ public class RegisterAllocator {
Variable[] definedVariables = definitionExtractor.getDefinedVariables(); Variable[] definedVariables = definitionExtractor.getDefinedVariables();
for (Variable definedVariable : definedVariables) { for (Variable definedVariable : definedVariables) {
addExceptionHandlingCopies(catchIncomingsByVariable, definedVariable, block, addExceptionHandlingCopies(catchIncomingsByVariable, definedVariable, block,
program, instruction.getLocation(), nextInstructions); program, instruction.getLocation(), nextInstructions, categories);
} }
if (!nextInstructions.isEmpty()) { if (!nextInstructions.isEmpty()) {
@ -198,6 +207,7 @@ public class RegisterAllocator {
BasicBlock block = incoming.getSource(); BasicBlock block = incoming.getSource();
Variable copy = program.createVariable(); Variable copy = program.createVariable();
categories.add(categories.get(incoming.getPhi().getReceiver().getIndex()));
copy.setLabel(incoming.getPhi().getReceiver().getLabel()); copy.setLabel(incoming.getPhi().getReceiver().getLabel());
copy.setDebugName(incoming.getPhi().getReceiver().getDebugName()); copy.setDebugName(incoming.getPhi().getReceiver().getDebugName());
@ -213,7 +223,8 @@ public class RegisterAllocator {
} }
private void addExceptionHandlingCopies(List<List<Incoming>> catchIncomingsByVariable, Variable definedVariable, private void addExceptionHandlingCopies(List<List<Incoming>> catchIncomingsByVariable, Variable definedVariable,
BasicBlock block, Program program, TextLocation location, List<Instruction> nextInstructions) { BasicBlock block, Program program, TextLocation location, List<Instruction> nextInstructions,
List<Object> categories) {
if (definedVariable.getIndex() >= catchIncomingsByVariable.size()) { if (definedVariable.getIndex() >= catchIncomingsByVariable.size()) {
return; return;
} }
@ -228,6 +239,7 @@ public class RegisterAllocator {
Variable copy = program.createVariable(); Variable copy = program.createVariable();
copy.setLabel(incoming.getPhi().getReceiver().getLabel()); copy.setLabel(incoming.getPhi().getReceiver().getLabel());
copy.setDebugName(incoming.getPhi().getReceiver().getDebugName()); copy.setDebugName(incoming.getPhi().getReceiver().getDebugName());
categories.add(categories.get(incoming.getPhi().getReceiver().getIndex()));
AssignInstruction copyInstruction = new AssignInstruction(); AssignInstruction copyInstruction = new AssignInstruction();
copyInstruction.setReceiver(copy); copyInstruction.setReceiver(copy);
@ -242,11 +254,12 @@ public class RegisterAllocator {
} }
} }
private void insertCopy(Incoming incoming, Map<BasicBlock, BasicBlock> blockMap) { private void insertCopy(Incoming incoming, Map<BasicBlock, BasicBlock> blockMap, List<Object> categories) {
Phi phi = incoming.getPhi(); Phi phi = incoming.getPhi();
Program program = phi.getBasicBlock().getProgram(); Program program = phi.getBasicBlock().getProgram();
AssignInstruction copyInstruction = new AssignInstruction(); AssignInstruction copyInstruction = new AssignInstruction();
Variable firstCopy = program.createVariable(); Variable firstCopy = program.createVariable();
categories.add(categories.get(phi.getReceiver().getIndex()));
firstCopy.setLabel(phi.getReceiver().getLabel()); firstCopy.setLabel(phi.getReceiver().getLabel());
firstCopy.setDebugName(phi.getReceiver().getDebugName()); firstCopy.setDebugName(phi.getReceiver().getDebugName());
copyInstruction.setReceiver(firstCopy); copyInstruction.setReceiver(firstCopy);
@ -274,7 +287,7 @@ public class RegisterAllocator {
} }
private void removeRedundantCopies(Program program, List<MutableGraphNode> interferenceGraph, private void removeRedundantCopies(Program program, List<MutableGraphNode> interferenceGraph,
DisjointSet congruenceClasses) { DisjointSet congruenceClasses, Object[] categories) {
for (int i = 0; i < program.basicBlockCount(); ++i) { for (int i = 0; i < program.basicBlockCount(); ++i) {
BasicBlock block = program.basicBlockAt(i); BasicBlock block = program.basicBlockAt(i);
Instruction nextInsn; Instruction nextInsn;
@ -297,7 +310,8 @@ public class RegisterAllocator {
break; break;
} }
} }
if (!interfere) { if (!interfere && Objects.equals(categories[assignment.getReceiver().getIndex()],
categories[assignment.getAssignee().getIndex()])) {
int newClass = congruenceClasses.union(copyClass, origClass); int newClass = congruenceClasses.union(copyClass, origClass);
insn.delete(); insn.delete();
if (newClass == interferenceGraph.size()) { if (newClass == interferenceGraph.size()) {