35#ifndef LLVM_ANALYSIS_IR2VEC_H
36#define LLVM_ANALYSIS_IR2VEC_H
89 std::vector<double> Data;
93 Embedding(
const std::vector<double> &V) : Data(V) {}
95 Embedding(std::initializer_list<double> IL) : Data(IL) {}
100 size_t size()
const {
return Data.size(); }
101 bool empty()
const {
return Data.empty(); }
104 assert(Itr < Data.size() &&
"Index out of bounds");
109 assert(Itr < Data.size() &&
"Index out of bounds");
113 using iterator =
typename std::vector<double>::iterator;
123 const std::vector<double> &
getData()
const {
return Data; }
140 double Tolerance = 1e-4)
const;
154 std::vector<std::vector<Embedding>> Sections;
159 size_t TotalSize = 0;
160 unsigned Dimension = 0;
167 VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
176 size_t size()
const {
return TotalSize; }
180 return static_cast<unsigned>(Sections.size());
184 const std::vector<Embedding> &
operator[](
unsigned SectionId)
const {
185 assert(SectionId < Sections.size() &&
"Invalid section ID");
186 return Sections[SectionId];
193 bool isValid()
const {
return TotalSize > 0; }
198 unsigned SectionId = 0;
199 size_t LocalIndex = 0;
204 : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
222 VocabMap &TargetVocab,
unsigned &Dim);
261 enum class Section :
unsigned {
272 static constexpr unsigned NumICmpPredicates =
275 static constexpr unsigned NumFCmpPredicates =
307#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
308#include "llvm/IR/Instruction.def"
309#undef LAST_OTHER_INST
311 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
319 NumICmpPredicates + NumFCmpPredicates;
331 return Storage.size() == NumCanonicalEntries;
336 return Storage.getDimension();
348 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(
TypeID));
353 unsigned Index =
static_cast<unsigned>(Kind);
355 return OperandKindNames[Index];
366 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
372 return MaxOpcodes +
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
378 return OperandBaseOffset + Index;
382 return PredicateBaseOffset + getPredicateLocalIndex(
P);
387 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
388 return Storage[
static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
393 unsigned LocalIndex =
static_cast<unsigned>(getCanonicalTypeID(
TypeID));
394 return Storage[
static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
398 unsigned LocalIndex =
static_cast<unsigned>(
getOperandKind(&Arg));
400 return Storage[
static_cast<unsigned>(Section::Operands)][LocalIndex];
404 unsigned LocalIndex = getPredicateLocalIndex(
P);
405 return Storage[
static_cast<unsigned>(Section::Predicates)][LocalIndex];
413 return Storage.begin();
420 return Storage.end();
434 ModuleAnalysisManager::Invalidator &Inv)
const;
437 constexpr static unsigned NumCanonicalEntries =
441 constexpr static unsigned OperandBaseOffset =
443 constexpr static unsigned PredicateBaseOffset =
452 "FloatTy",
"VoidTy",
"LabelTy",
"MetadataTy",
453 "VectorTy",
"TokenTy",
"IntegerTy",
"FunctionTy",
454 "PointerTy",
"StructTy",
"ArrayTy",
"UnknownTy"};
455 static_assert(std::size(CanonicalTypeNames) ==
457 "CanonicalTypeNames array size must match MaxCanonicalType");
460 static constexpr StringLiteral OperandKindNames[] = {
"Function",
"Pointer",
461 "Constant",
"Variable"};
462 static_assert(std::size(OperandKindNames) ==
464 "OperandKindNames array size must match MaxOperandKind");
468 static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
491 static_assert(TypeIDMapping.size() ==
MaxTypeIDs,
492 "TypeIDMapping must cover all Type::TypeID values");
497 unsigned Index =
static_cast<unsigned>(CType);
499 return CanonicalTypeNames[
Index];
506 return TypeIDMapping[
Index];
514 return getPredicateFromLocalIndex(Index);
556 LLVM_ABI static std::unique_ptr<Embedder>
583 void computeEmbeddings(
const BasicBlock &BB)
const override;
595 void computeEmbeddings(
const BasicBlock &BB)
const override;
608 using VocabMap = std::map<std::string, ir2vec::Embedding>;
609 std::optional<ir2vec::VocabStorage> Vocab;
611 Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
613 void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file defines the DenseMap class.
Provides ErrorOr<T> smart pointer.
This header defines various interfaces for pass management in LLVM.
This file supports working with JSON data.
ModuleAnalysisManager MAM
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
LLVM Basic Block Representation.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Lightweight error class with error context and mandatory checking.
IR2VecPrinterPass(raw_ostream &OS)
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This analysis provides the vocabulary for IR2Vec.
IR2VecVocabAnalysis()=default
ir2vec::Vocabulary Result
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
LLVM_ABI IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
static LLVM_ABI AnalysisKey Key
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
IR2VecVocabPrinterPass(raw_ostream &OS)
This is an important class for using LLVM in a threaded context.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
StringRef - Represent a constant reference to a string, i.e.
TypeID
Definitions of all of the base types for the Type system.
LLVM Value Representation.
LLVM_ABI const Embedding & getBBVector(const BasicBlock &BB) const
Returns the embedding for a given basic block in the function F if it has been computed.
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
LLVM_ABI const BBEmbeddingsMap & getBBVecMap() const
Returns a map containing basic block and the corresponding embeddings for the function F if it has be...
void computeEmbeddings() const
Function to compute embeddings.
virtual ~Embedder()=default
LLVM_ABI const InstEmbeddingsMap & getInstVecMap() const
Returns a map containing instructions and the corresponding embeddings for the function F if it has b...
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
virtual void computeEmbeddings(const BasicBlock &BB) const =0
Function to compute the embedding for a given basic block.
LLVM_ABI const Embedding & getFunctionVector() const
Computes and returns the embedding for the current function.
InstEmbeddingsMap InstVecMap
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
Iterator support for section-based access.
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
LLVM_ABI bool operator!=(const const_iterator &Other) const
LLVM_ABI const_iterator & operator++()
LLVM_ABI const Embedding & operator*() const
LLVM_ABI bool operator==(const const_iterator &Other) const
Generic storage class for section-based vocabularies.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
VocabStorage & operator=(VocabStorage &&)=default
const_iterator end() const
unsigned getNumSections() const
Get number of sections.
VocabStorage()
Default constructor creates empty storage (invalid state)
VocabStorage & operator=(const VocabStorage &)=delete
unsigned getDimension() const
Get vocabulary dimension.
size_t size() const
Get total number of entries across all sections.
const_iterator begin() const
bool isValid() const
Check if vocabulary is valid (has data)
VocabStorage(VocabStorage &&)=default
std::map< std::string, Embedding > VocabMap
const std::vector< Embedding > & operator[](unsigned SectionId) const
Section-based access: Storage[sectionId][localIndex].
VocabStorage(const VocabStorage &)=delete
Class for storing and accessing the IR2Vec vocabulary.
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
const_iterator begin() const
LLVM_ABI unsigned getDimension() const
Vocabulary(Vocabulary &&)=default
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
static LLVM_ABI unsigned getIndex(CmpInst::Predicate P)
Vocabulary & operator=(const Vocabulary &)=delete
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
static constexpr unsigned MaxCanonicalTypeIDs
LLVM_ABI const ir2vec::Embedding & operator[](CmpInst::Predicate P) const
static constexpr unsigned MaxOperandKinds
Vocabulary(const Vocabulary &)=delete
const_iterator cbegin() const
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
static constexpr size_t getCanonicalSize()
Total number of entries (opcodes + canonicalized types + operand kinds + predicates)
static LLVM_ABI unsigned getIndex(const Value &Op)
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
static constexpr unsigned MaxTypeIDs
LLVM_ABI Vocabulary(VocabStorage &&Storage)
LLVM_ABI const ir2vec::Embedding & operator[](Type::TypeID TypeID) const
static LLVM_ABI unsigned getIndex(Type::TypeID TypeID)
const_iterator end() const
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
VocabStorage::const_iterator const_iterator
Const Iterator type aliases.
const_iterator cend() const
static LLVM_ABI unsigned getIndex(unsigned Opcode)
Functions to return flat index.
LLVM_ABI bool isValid() const
Vocabulary & operator=(Vocabulary &&Other)=delete
LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
static constexpr unsigned MaxPredicateKinds
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
LLVM_ABI const ir2vec::Embedding & operator[](const Value &Arg) const
A Value is an JSON value of unknown type.
This class implements an extremely fast bulk output stream that can only output to a stream.
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
This is an optimization pass for GlobalISel generic memory operations.
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Implement std::hash so that hash_code can be used in STL containers.
A CRTP mix-in that provides informational APIs needed for analysis passes.
A special type used by analysis passes to provide an address that identifies that particular analysis...
A CRTP mix-in to automatically provide informational APIs needed for passes.
Embedding is a datatype that wraps std::vector<double>.
const_iterator end() const
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
const_iterator cbegin() const
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
LLVM_ABI Embedding operator-(const Embedding &RHS) const
const std::vector< double > & getData() const
typename std::vector< double >::const_iterator const_iterator
Embedding(size_t Size, double InitialValue)
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
const_iterator cend() const
LLVM_ABI Embedding operator*(double Factor) const
LLVM_ABI Embedding & operator*=(double Factor)
Embedding(std::initializer_list< double > IL)
Embedding(const std::vector< double > &V)
LLVM_ABI Embedding operator+(const Embedding &RHS) const
typename std::vector< double >::iterator iterator
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Embedding(std::vector< double > &&V)
const double & operator[](size_t Itr) const
LLVM_ABI void print(raw_ostream &OS) const
const_iterator begin() const
double & operator[](size_t Itr)