Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions src/ir/type-updating.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,37 @@

namespace wasm {

GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}

void GlobalTypeRewriter::update() {
mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
}

GlobalTypeRewriter::PredecessorGraph
GlobalTypeRewriter::getPrivatePredecessors() {
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm)
: wasm(wasm), publicGroups(wasm.features) {
// Find the heap types that are not publicly observable. Even in a closed
// world scenario, don't modify public types because we assume that they may
// be reflected on or used for linking. Figure out where each private type
// will be located in the builder.
auto typeInfo = ModuleUtils::collectHeapTypeInfo(
typeInfo = ModuleUtils::collectHeapTypeInfo(
wasm,
ModuleUtils::TypeInclusion::UsedIRTypes,
ModuleUtils::VisibilityHandling::FindVisibility);

// Check if a type is private, by looking up its info.
std::unordered_set<RecGroup> seenGroups;
for (auto& [type, info] : typeInfo) {
if (info.visibility == ModuleUtils::Visibility::Public) {
auto group = type.getRecGroup();
if (seenGroups.insert(type.getRecGroup()).second) {
std::vector<HeapType> groupTypes(group.begin(), group.end());
publicGroups.insert(std::move(groupTypes));
}
}
}
}

void GlobalTypeRewriter::update() {
mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
}

GlobalTypeRewriter::PredecessorGraph
GlobalTypeRewriter::getPrivatePredecessors() {
// Check if a type is private, looking for its info (if there is none, it is
// not private).
auto isPublic = [&](HeapType type) {
auto it = typeInfo.find(type);
assert(it != typeInfo.end());
Expand Down Expand Up @@ -185,11 +198,8 @@ GlobalTypeRewriter::rebuildTypes(std::vector<HeapType> types) {
<< " at index " << err->index;
}
#endif
auto& newTypes = *buildResults;

// TODO: It is possible that the newly built rec group matches some public rec
// group. If that is the case, we need to try a different permutation of the
// types or add a brand type to distinguish the private types.
// Ensure the new types are different from any public rec group.
const auto& newTypes = publicGroups.insert(*buildResults);

// Map the old types to the new ones.
TypeMap oldToNewTypes;
Expand Down
10 changes: 10 additions & 0 deletions src/ir/type-updating.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#define wasm_ir_type_updating_h

#include "ir/branch-utils.h"
#include "ir/module-utils.h"
#include "support/insert_ordered.h"
#include "wasm-traversal.h"
#include "wasm-type-shape.h"
#include "wasm-type.h"

namespace wasm {

Expand Down Expand Up @@ -348,6 +351,13 @@ class GlobalTypeRewriter {

Module& wasm;

// The module's types and their visibilities.
InsertOrderedMap<HeapType, ModuleUtils::HeapTypeInfo> typeInfo;

// The shapes of public rec groups, so we can be sure that the rewritten
// private types do not conflict with public types.
UniqueRecGroups publicGroups;

GlobalTypeRewriter(Module& wasm);
virtual ~GlobalTypeRewriter() {}

Expand Down
78 changes: 0 additions & 78 deletions src/passes/MinimizeRecGroups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,84 +100,6 @@ struct TypeSCCs
}
};

// After all their permutations with distinct shapes have been used, different
// groups with the same shapes must be differentiated by adding in a "brand"
// type. Even with a brand mixed in, we might run out of permutations with
// distinct shapes, in which case we need a new brand type. This iterator
// provides an infinite sequence of possible brand types, prioritizing those
// with the most compact encoding.
struct BrandTypeIterator {
static constexpr Index optionCount = 18;
static constexpr std::array<Field, optionCount> fieldOptions = {{
Field(Field::i8, Mutable),
Field(Field::i16, Mutable),
Field(Type::i32, Mutable),
Field(Type::i64, Mutable),
Field(Type::f32, Mutable),
Field(Type::f64, Mutable),
Field(Type(HeapType::any, Nullable), Mutable),
Field(Type(HeapType::func, Nullable), Mutable),
Field(Type(HeapType::ext, Nullable), Mutable),
Field(Type(HeapType::none, Nullable), Mutable),
Field(Type(HeapType::nofunc, Nullable), Mutable),
Field(Type(HeapType::noext, Nullable), Mutable),
Field(Type(HeapType::any, NonNullable), Mutable),
Field(Type(HeapType::func, NonNullable), Mutable),
Field(Type(HeapType::ext, NonNullable), Mutable),
Field(Type(HeapType::none, NonNullable), Mutable),
Field(Type(HeapType::nofunc, NonNullable), Mutable),
Field(Type(HeapType::noext, NonNullable), Mutable),
}};

struct FieldInfo {
uint8_t index = 0;
bool immutable = false;

operator Field() const {
auto field = fieldOptions[index];
if (immutable) {
field.mutable_ = Immutable;
}
return field;
}

bool advance() {
if (!immutable) {
immutable = true;
return true;
}
immutable = false;
index = (index + 1) % optionCount;
return index != 0;
}
};

bool useArray = false;
std::vector<FieldInfo> fields;

HeapType operator*() const {
if (useArray) {
return Array(fields[0]);
}
return Struct(std::vector<Field>(fields.begin(), fields.end()));
}

BrandTypeIterator& operator++() {
for (Index i = fields.size(); i > 0; --i) {
if (fields[i - 1].advance()) {
return *this;
}
}
if (useArray) {
useArray = false;
return *this;
}
fields.emplace_back();
useArray = fields.size() == 1;
return *this;
}
};

// Create an adjacency list with edges from supertype to subtype and from
// described type to descriptor.
std::vector<std::vector<Index>>
Expand Down
98 changes: 98 additions & 0 deletions src/wasm-type-shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#define wasm_wasm_type_shape_h

#include <functional>
#include <list>
#include <unordered_set>
#include <vector>

#include "wasm-features.h"
Expand Down Expand Up @@ -79,4 +81,100 @@ template<> class hash<wasm::RecGroupShape> {

} // namespace std

namespace wasm {

// Provides an infinite sequence of possible brand types, prioritizing those
// with the most compact encoding.
struct BrandTypeIterator {
static constexpr Index optionCount = 18;
static constexpr std::array<Field, optionCount> fieldOptions = {{
Field(Field::i8, Mutable),
Field(Field::i16, Mutable),
Field(Type::i32, Mutable),
Field(Type::i64, Mutable),
Field(Type::f32, Mutable),
Field(Type::f64, Mutable),
Field(Type(HeapType::any, Nullable), Mutable),
Field(Type(HeapType::func, Nullable), Mutable),
Field(Type(HeapType::ext, Nullable), Mutable),
Field(Type(HeapType::none, Nullable), Mutable),
Field(Type(HeapType::nofunc, Nullable), Mutable),
Field(Type(HeapType::noext, Nullable), Mutable),
Field(Type(HeapType::any, NonNullable), Mutable),
Field(Type(HeapType::func, NonNullable), Mutable),
Field(Type(HeapType::ext, NonNullable), Mutable),
Field(Type(HeapType::none, NonNullable), Mutable),
Field(Type(HeapType::nofunc, NonNullable), Mutable),
Field(Type(HeapType::noext, NonNullable), Mutable),
}};

struct FieldInfo {
uint8_t index = 0;
bool immutable = false;

operator Field() const {
auto field = fieldOptions[index];
if (immutable) {
field.mutable_ = Immutable;
}
return field;
}

bool advance() {
if (!immutable) {
immutable = true;
return true;
}
immutable = false;
index = (index + 1) % optionCount;
return index != 0;
}
};

bool useArray = false;
std::vector<FieldInfo> fields;

HeapType operator*() const {
if (useArray) {
return Array(fields[0]);
}
return Struct(std::vector<Field>(fields.begin(), fields.end()));
}

BrandTypeIterator& operator++() {
for (Index i = fields.size(); i > 0; --i) {
if (fields[i - 1].advance()) {
return *this;
}
}
if (useArray) {
useArray = false;
return *this;
}
fields.emplace_back();
useArray = fields.size() == 1;
return *this;
}
};

// A set of unique rec group shapes. Upon inserting a new group of types, if it
// has the same shape as a previously inserted group, the types will be rebuilt
// with an extra brand type at the end of the group that differentiates it from
// previous group.
struct UniqueRecGroups {
std::list<std::vector<HeapType>> groups;
std::unordered_set<RecGroupShape> shapes;

FeatureSet features;

UniqueRecGroups(FeatureSet features) : features(features) {}

// Insert a rec group. If it is already unique, return the original types.
// Otherwise rebuild the group make it unique and return the rebuilt types,
// including the brand.
const std::vector<HeapType>& insert(std::vector<HeapType> group);
};

} // namespace wasm

#endif // wasm_wasm_type_shape_h
35 changes: 35 additions & 0 deletions src/wasm/wasm-type-shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const {
return GT == compareComparable(*this, other);
}

const std::vector<HeapType>&
UniqueRecGroups::insert(std::vector<HeapType> types) {
auto& group = *groups.emplace(groups.end(), std::move(types));
if (shapes.emplace(RecGroupShape(group, features)).second) {
// The types are already unique.
return group;
}
// There is a conflict. Find a brand that makes the group unique.
BrandTypeIterator brand;
group.push_back(*brand);
while (!shapes.emplace(RecGroupShape(group, features)).second) {
group.back() = *++brand;
}
// Rebuild the rec group to include the brand. Map the old types (excluding
// the brand) to their corresponding new types to preserve recursions within
// the group.
Index size = group.size();
TypeBuilder builder(size);
std::unordered_map<HeapType, HeapType> newTypes;
for (Index i = 0; i < size - 1; ++i) {
newTypes[group[i]] = builder[i];
}
for (Index i = 0; i < size; ++i) {
builder[i].copy(group[i], [&](HeapType type) {
if (auto newType = newTypes.find(type); newType != newTypes.end()) {
return newType->second;
}
return type;
});
}
builder.createRecGroup(0, size);
group = *builder.build();
return group;
}

} // namespace wasm

namespace std {
Expand Down
59 changes: 59 additions & 0 deletions test/lit/passes/signature-pruning-public-collision.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
;; RUN: wasm-opt %s -all --closed-world --signature-pruning --fuzz-exec -S -o - | filecheck %s

(module
;; CHECK: (type $public (func))

;; CHECK: (rec
;; CHECK-NEXT: (type $private (func))

;; CHECK: (type $2 (struct))

;; CHECK: (type $test (func (result i32)))
(type $test (func (result i32)))

(type $public (func))

;; After signature pruning this will be (func), which is the same as $public.
;; We must make sure we keep $private a distinct type.
(type $private (func (param i32)))

;; CHECK: (import "" "" (func $public (type $public)))
(import "" "" (func $public (type $public)))

;; CHECK: (elem declare func $public)

;; CHECK: (export "test" (func $test))

;; CHECK: (func $private (type $private)
;; CHECK-NEXT: (local $0 i32)
;; CHECK-NEXT: (nop)
;; CHECK-NEXT: )
(func $private (type $private) (param $unused i32)
(nop)
)

;; CHECK: (func $test (type $test) (result i32)
;; CHECK-NEXT: (local $0 funcref)
;; CHECK-NEXT: (ref.test (ref $private)
;; CHECK-NEXT: (select (result funcref)
;; CHECK-NEXT: (ref.func $public)
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: (i32.const 1)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $test (export "test") (type $test) (result i32)
(local funcref)
;; Test that $private and $public are separate types. This should return 0.
(ref.test (ref $private)
;; Use select to prevent the ref.test from being optimized in
;; finalization.
(select (result funcref)
(ref.func $public)
(local.get 0)
(i32.const 1)
)
)
)
)
Loading