30using namespace ir2vec;
32#define DEBUG_TYPE "ir2vec"
35 "Number of lookups to entities not present in the vocabulary");
47 cl::desc(
"Weight for opcode embeddings"),
50 cl::desc(
"Weight for type embeddings"),
53 cl::desc(
"Weight for argument embeddings"),
58 "Generate symbolic embeddings"),
60 "Generate flow-aware embeddings")),
75 std::vector<double> TempOut;
87 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
88 std::transform(this->
begin(), this->
end(), RHS.
begin(), this->begin(),
100 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
101 std::transform(this->
begin(), this->
end(), RHS.
begin(), this->begin(),
102 std::minus<double>());
114 [Factor](
double Elem) {
return Elem * Factor; });
125 assert(this->
size() == Src.
size() &&
"Vectors must have the same dimension");
126 for (
size_t Itr = 0; Itr < this->
size(); ++Itr)
127 (*
this)[Itr] += Src[Itr] * Factor;
132 double Tolerance)
const {
133 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
134 for (
size_t Itr = 0; Itr < this->
size(); ++Itr)
135 if (std::abs((*
this)[Itr] -
RHS[Itr]) > Tolerance) {
136 LLVM_DEBUG(
errs() <<
"Embedding mismatch at index " << Itr <<
": "
137 << (*
this)[Itr] <<
" vs " <<
RHS[Itr]
138 <<
"; Tolerance: " << Tolerance <<
"\n");
146 for (
const auto &Elem : Data)
147 OS <<
" " <<
format(
"%.2f", Elem) <<
" ";
156 :
F(
F), Vocab(Vocab), Dimension(Vocab.getDimension()),
164 return std::make_unique<SymbolicEmbedder>(
F,
Vocab);
166 return std::make_unique<FlowAwareEmbedder>(
F,
Vocab);
215 for (
const auto &
Op :
I.operands())
218 Vocab[
I.getOpcode()] +
Vocab[
I.getType()->getTypeID()] + ArgEmb;
220 BBVector += InstVector;
233 for (
const auto &
Op :
I.operands()) {
235 if (
const auto *DefInst = dyn_cast<Instruction>(
Op)) {
238 "Instruction should have been processed before its operands");
239 ArgEmb += DefIt->second;
244 LLVM_DEBUG(
errs() <<
"Using embedding from vocabulary for operand: "
252 Vocab[
I.getOpcode()] +
Vocab[
I.getType()->getTypeID()] + ArgEmb;
254 BBVector += InstVector;
267 return Vocab.size() == NumCanonicalEntries && Valid;
271 assert(Valid &&
"IR2Vec Vocabulary is invalid");
272 return Vocab[0].size();
276 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
282 return MaxOpcodes +
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
304 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
305#define HANDLE_INST(NUM, OPCODE, CLASS) \
306 if (Opcode == NUM) { \
309#include "llvm/IR/Instruction.def"
311 return "UnknownOpcode";
314StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
315 unsigned Index =
static_cast<unsigned>(CType);
317 return CanonicalTypeNames[Index];
322 unsigned Index =
static_cast<unsigned>(
TypeID);
324 return TypeIDMapping[Index];
328 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(
TypeID));
332 unsigned Index =
static_cast<unsigned>(Kind);
334 return OperandKindNames[Index];
339 if (isa<Function>(
Op))
341 if (isa<PointerType>(
Op->getType()))
343 if (isa<Constant>(
Op))
349 assert(Pos < NumCanonicalEntries &&
"Position out of bounds in vocabulary");
351 if (Pos < MaxOpcodes)
355 return getVocabKeyForCanonicalTypeID(
366 return !(PAC.preservedWhenStateless());
370 VocabVector DummyVocab;
371 DummyVocab.reserve(NumCanonicalEntries);
372 float DummyVal = 0.1f;
375 for ([[maybe_unused]]
unsigned _ :
378 DummyVocab.push_back(
Embedding(Dim, DummyVal));
388Error IR2VecVocabAnalysis::parseVocabSection(
395 "JSON root is not an object");
400 "Missing '" + std::string(Key) +
401 "' section in vocabulary file");
404 "Unable to parse '" + std::string(Key) +
405 "' section from vocabulary");
407 Dim = TargetVocab.begin()->second.size();
410 "Dimension of '" + std::string(Key) +
411 "' section of the vocabulary is zero");
413 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
414 [Dim](
const std::pair<StringRef, Embedding> &Entry) {
415 return Entry.second.size() == Dim;
419 "All vectors in the '" + std::string(Key) +
420 "' section of the vocabulary are not of the same dimension");
427Error IR2VecVocabAnalysis::readVocabulary() {
432 auto Content = BufOrError.get()->getBuffer();
435 if (!ParsedVocabValue)
438 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
440 parseVocabSection(
"Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
444 parseVocabSection(
"Types", *ParsedVocabValue, TypeVocab, TypeDim))
448 parseVocabSection(
"Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
451 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
453 "Vocabulary sections have different dimensions");
458void IR2VecVocabAnalysis::generateNumMappedVocab() {
463 auto handleMissingEntity = [](
const std::string &Val) {
465 <<
" is not in vocabulary, using zero vector; This "
466 "would result in an error in future.\n");
470 unsigned Dim = OpcVocab.begin()->second.size();
471 assert(Dim > 0 &&
"Vocabulary dimension must be greater than zero");
474 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
476 NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
477 for (
unsigned Opcode :
seq(0u, Vocabulary::MaxOpcodes)) {
479 auto It = OpcVocab.find(VocabKey.
str());
480 if (It != OpcVocab.end())
481 NumericOpcodeEmbeddings[Opcode] = It->second;
483 handleMissingEntity(VocabKey.
str());
485 Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
486 NumericOpcodeEmbeddings.end());
493 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
495 if (
auto It = TypeVocab.find(VocabKey.
str()); It != TypeVocab.end()) {
496 NumericTypeEmbeddings[CTypeID] = It->second;
499 handleMissingEntity(VocabKey.
str());
501 Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
502 NumericTypeEmbeddings.end());
511 auto It = ArgVocab.find(VocabKey.
str());
512 if (It != ArgVocab.end()) {
513 NumericArgEmbeddings[OpKind] = It->second;
516 handleMissingEntity(VocabKey.
str());
518 Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
519 NumericArgEmbeddings.end());
536 auto Ctx = &M.getContext();
544 Ctx->
emitError(
"IR2Vec vocabulary file path not specified; You may need to "
545 "set it using --ir2vec-vocab-path");
548 if (
auto Err = readVocabulary()) {
549 emitError(std::move(Err), *Ctx);
554 auto scaleVocabSection = [](VocabMap &Vocab,
double Weight) {
555 for (
auto &Entry : Vocab)
556 Entry.second *= Weight;
563 generateNumMappedVocab();
580 OS <<
"Error creating IR2Vec embeddings \n";
584 OS <<
"IR2Vec embeddings for function " <<
F.getName() <<
":\n";
585 OS <<
"Function vector: ";
586 Emb->getFunctionVector().print(OS);
588 OS <<
"Basic block vectors:\n";
589 const auto &BBMap = Emb->getBBVecMap();
591 auto It = BBMap.find(&BB);
592 if (It != BBMap.end()) {
593 OS <<
"Basic block: " << BB.
getName() <<
":\n";
594 It->second.print(OS);
598 OS <<
"Instruction vectors:\n";
599 const auto &InstMap = Emb->getInstVecMap();
602 auto It = InstMap.find(&
I);
603 if (It != InstMap.end()) {
604 OS <<
"Instruction: ";
606 It->second.print(OS);
617 assert(IR2VecVocabulary.isValid() &&
"IR2Vec Vocabulary is invalid");
621 for (
const auto &Entry : IR2VecVocabulary) {
622 OS <<
"Key: " << IR2VecVocabulary.getStringKey(Pos++) <<
": ";
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
API to communicate dependencies between analyses during invalidation.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
LLVM Basic Block Representation.
LLVM_ABI iterator_range< filter_iterator< BasicBlock::const_iterator, std::function< bool(const Instruction &)> > > instructionsWithoutDebug(bool SkipPseudoOp=true) const
Return a const iterator range over the instructions in the block, skipping any debug instructions.
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
Base class for error info classes.
virtual std::string message() const
Return the error message as a string.
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or a Error.
Error takeError()
Take ownership of the stored error.
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This analysis provides the vocabulary for IR2Vec.
IR2VecVocabAnalysis()=default
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
static LLVM_ABI AnalysisKey Key
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This is an important class for using LLVM in a threaded context.
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
TypeID
Definitions of all of the base types for the Type system.
LLVM Value Representation.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI const Embedding & getBBVector(const BasicBlock &BB) const
Returns the embedding for a given basic block in the function F if it has been computed.
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
LLVM_ABI const BBEmbeddingsMap & getBBVecMap() const
Returns a map containing basic block and the corresponding embeddings for the function F if it has be...
void computeEmbeddings() const
Function to compute embeddings.
LLVM_ABI const InstEmbeddingsMap & getInstVecMap() const
Returns a map containing instructions and the corresponding embeddings for the function F if it has b...
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
LLVM_ABI const Embedding & getFunctionVector() const
Computes and returns the embedding for the current function.
InstEmbeddingsMap InstVecMap
Class for storing and accessing the IR2Vec vocabulary.
static LLVM_ABI unsigned getSlotIndex(unsigned Opcode)
Functions to return the slot index or position of a given Opcode, TypeID, or OperandKind in the vocab...
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
LLVM_ABI bool isValid() const
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
static constexpr unsigned MaxCanonicalTypeIDs
static LLVM_ABI VocabVector createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
static constexpr unsigned MaxOperandKinds
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
static constexpr unsigned MaxTypeIDs
static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
LLVM_ABI unsigned getDimension() const
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
An Object is a JSON object, which maps strings to heterogenous JSON values.
LLVM_ABI Value * get(StringRef K)
The root is the trivial Path to the root value.
A "cursor" marking a position within a Value.
A Value is an JSON value of unknown type.
const json::Object * getAsObject() const
This class implements an extremely fast bulk output stream that can only output to a stream.
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
static cl::opt< std::string > VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory))
LLVM_ABI cl::opt< float > ArgWeight
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
bool fromJSON(const Value &E, std::string &Out, Path P)
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Implement std::hash so that hash_code can be used in STL containers.
A special type used by analysis passes to provide an address that identifies that particular analysis...
Embedding is a datatype that wraps std::vector<double>.
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
LLVM_ABI Embedding operator-(const Embedding &RHS) const
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
LLVM_ABI Embedding operator*(double Factor) const
LLVM_ABI Embedding & operator*=(double Factor)
LLVM_ABI Embedding operator+(const Embedding &RHS) const
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
LLVM_ABI void print(raw_ostream &OS) const