LLVM 22.0.0git
IR2Vec.cpp
Go to the documentation of this file.
1//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec algorithm.
11///
12//===----------------------------------------------------------------------===//
13
15
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/Statistic.h"
19#include "llvm/IR/CFG.h"
20#include "llvm/IR/Module.h"
21#include "llvm/IR/PassManager.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/Errc.h"
24#include "llvm/Support/Error.h"
26#include "llvm/Support/Format.h"
28
29using namespace llvm;
30using namespace ir2vec;
31
32#define DEBUG_TYPE "ir2vec"
33
34STATISTIC(VocabMissCounter,
35 "Number of lookups to entities not present in the vocabulary");
36
37namespace llvm {
38namespace ir2vec {
40
41// FIXME: Use a default vocab when not specified
43 VocabFile("ir2vec-vocab-path", cl::Optional,
44 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
46cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
47 cl::desc("Weight for opcode embeddings"),
49cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
50 cl::desc("Weight for type embeddings"),
52cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
53 cl::desc("Weight for argument embeddings"),
56 "ir2vec-kind", cl::Optional,
58 "Generate symbolic embeddings"),
60 "Generate flow-aware embeddings")),
61 cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
63
64} // namespace ir2vec
65} // namespace llvm
66
68
69//===----------------------------------------------------------------------===//
70// Local helper functions
71//===----------------------------------------------------------------------===//
72namespace llvm::json {
73inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
75 std::vector<double> TempOut;
76 if (!llvm::json::fromJSON(E, TempOut, P))
77 return false;
78 Out = Embedding(std::move(TempOut));
79 return true;
80}
81} // namespace llvm::json
82
83//===----------------------------------------------------------------------===//
84// Embedding
85//===----------------------------------------------------------------------===//
87 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
88 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
89 std::plus<double>());
90 return *this;
91}
92
94 Embedding Result(*this);
95 Result += RHS;
96 return Result;
97}
98
100 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
101 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
102 std::minus<double>());
103 return *this;
104}
105
107 Embedding Result(*this);
108 Result -= RHS;
109 return Result;
110}
111
113 std::transform(this->begin(), this->end(), this->begin(),
114 [Factor](double Elem) { return Elem * Factor; });
115 return *this;
116}
117
118Embedding Embedding::operator*(double Factor) const {
119 Embedding Result(*this);
120 Result *= Factor;
121 return Result;
122}
123
124Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
125 assert(this->size() == Src.size() && "Vectors must have the same dimension");
126 for (size_t Itr = 0; Itr < this->size(); ++Itr)
127 (*this)[Itr] += Src[Itr] * Factor;
128 return *this;
129}
130
132 double Tolerance) const {
133 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
134 for (size_t Itr = 0; Itr < this->size(); ++Itr)
135 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
136 LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
137 << (*this)[Itr] << " vs " << RHS[Itr]
138 << "; Tolerance: " << Tolerance << "\n");
139 return false;
140 }
141 return true;
142}
143
145 OS << " [";
146 for (const auto &Elem : Data)
147 OS << " " << format("%.2f", Elem) << " ";
148 OS << "]\n";
149}
150
151//===----------------------------------------------------------------------===//
152// Embedder and its subclasses
153//===----------------------------------------------------------------------===//
154
159
160std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
161 const Vocabulary &Vocab) {
162 switch (Mode) {
164 return std::make_unique<SymbolicEmbedder>(F, Vocab);
166 return std::make_unique<FlowAwareEmbedder>(F, Vocab);
167 }
168 return nullptr;
169}
170
172 if (InstVecMap.empty())
174 return InstVecMap;
175}
176
178 if (BBVecMap.empty())
180 return BBVecMap;
181}
182
184 auto It = BBVecMap.find(&BB);
185 if (It != BBVecMap.end())
186 return It->second;
188 return BBVecMap[&BB];
189}
190
192 // Currently, we always (re)compute the embeddings for the function.
193 // This is cheaper than caching the vector.
195 return FuncVector;
196}
197
199 if (F.isDeclaration())
200 return;
201
202 // Consider only the basic blocks that are reachable from entry
203 for (const BasicBlock *BB : depth_first(&F)) {
205 FuncVector += BBVecMap[BB];
206 }
207}
208
210 Embedding BBVector(Dimension, 0);
211
212 // We consider only the non-debug and non-pseudo instructions
213 for (const auto &I : BB.instructionsWithoutDebug()) {
214 Embedding ArgEmb(Dimension, 0);
215 for (const auto &Op : I.operands())
216 ArgEmb += Vocab[*Op];
217 auto InstVector =
218 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
219 InstVecMap[&I] = InstVector;
220 BBVector += InstVector;
221 }
222 BBVecMap[&BB] = BBVector;
223}
224
226 Embedding BBVector(Dimension, 0);
227
228 // We consider only the non-debug and non-pseudo instructions
229 for (const auto &I : BB.instructionsWithoutDebug()) {
230 // TODO: Handle call instructions differently.
231 // For now, we treat them like other instructions
232 Embedding ArgEmb(Dimension, 0);
233 for (const auto &Op : I.operands()) {
234 // If the operand is defined elsewhere, we use its embedding
235 if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
236 auto DefIt = InstVecMap.find(DefInst);
237 assert(DefIt != InstVecMap.end() &&
238 "Instruction should have been processed before its operands");
239 ArgEmb += DefIt->second;
240 continue;
241 }
242 // If the operand is not defined by an instruction, we use the vocabulary
243 else {
244 LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
245 << *Op << "=" << Vocab[*Op][0] << "\n");
246 ArgEmb += Vocab[*Op];
247 }
248 }
249 // Create the instruction vector by combining opcode, type, and arguments
250 // embeddings
251 auto InstVector =
252 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
253 InstVecMap[&I] = InstVector;
254 BBVector += InstVector;
255 }
256 BBVecMap[&BB] = BBVector;
257}
258
259//===----------------------------------------------------------------------===//
260// Vocabulary
261//===----------------------------------------------------------------------===//
262
263unsigned Vocabulary::getDimension() const {
264 assert(isValid() && "IR2Vec Vocabulary is invalid");
265 return Vocab[0].size();
266}
267
268unsigned Vocabulary::getSlotIndex(unsigned Opcode) {
269 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
270 return Opcode - 1; // Convert to zero-based index
271}
272
274 assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
275 return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
276}
277
279 unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
280 assert(Index < MaxOperandKinds && "Invalid OperandKind");
281 return MaxOpcodes + MaxCanonicalTypeIDs + Index;
282}
283
284const Embedding &Vocabulary::operator[](unsigned Opcode) const {
285 return Vocab[getSlotIndex(Opcode)];
286}
287
289 return Vocab[getSlotIndex(TypeID)];
290}
291
293 return Vocab[getSlotIndex(Arg)];
294}
295
297 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
298#define HANDLE_INST(NUM, OPCODE, CLASS) \
299 if (Opcode == NUM) { \
300 return #OPCODE; \
301 }
302#include "llvm/IR/Instruction.def"
303#undef HANDLE_INST
304 return "UnknownOpcode";
305}
306
307StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
308 unsigned Index = static_cast<unsigned>(CType);
309 assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
310 return CanonicalTypeNames[Index];
311}
312
314Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
315 unsigned Index = static_cast<unsigned>(TypeID);
316 assert(Index < MaxTypeIDs && "Invalid TypeID");
317 return TypeIDMapping[Index];
318}
319
321 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
322}
323
325 unsigned Index = static_cast<unsigned>(Kind);
326 assert(Index < MaxOperandKinds && "Invalid OperandKind");
327 return OperandKindNames[Index];
328}
329
330// Helper function to classify an operand into OperandKind
340
342 assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
343 // Opcode
344 if (Pos < MaxOpcodes)
345 return getVocabKeyForOpcode(Pos + 1);
346 // Type
347 if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
348 return getVocabKeyForCanonicalTypeID(
349 static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
350 // Operand
352 static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
353}
354
355// For now, assume vocabulary is stable unless explicitly invalidated.
357 ModuleAnalysisManager::Invalidator &Inv) const {
358 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
359 return !(PAC.preservedWhenStateless());
360}
361
362Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
363 VocabVector DummyVocab;
364 DummyVocab.reserve(NumCanonicalEntries);
365 float DummyVal = 0.1f;
366 // Create a dummy vocabulary with entries for all opcodes, types, and
367 // operands
368 for ([[maybe_unused]] unsigned _ :
369 seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
371 DummyVocab.push_back(Embedding(Dim, DummyVal));
372 DummyVal += 0.1f;
373 }
374 return DummyVocab;
375}
376
377//===----------------------------------------------------------------------===//
378// IR2VecVocabAnalysis
379//===----------------------------------------------------------------------===//
380
381Error IR2VecVocabAnalysis::parseVocabSection(
382 StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
383 unsigned &Dim) {
384 json::Path::Root Path("");
385 const json::Object *RootObj = ParsedVocabValue.getAsObject();
386 if (!RootObj)
388 "JSON root is not an object");
389
390 const json::Value *SectionValue = RootObj->get(Key);
391 if (!SectionValue)
393 "Missing '" + std::string(Key) +
394 "' section in vocabulary file");
395 if (!json::fromJSON(*SectionValue, TargetVocab, Path))
397 "Unable to parse '" + std::string(Key) +
398 "' section from vocabulary");
399
400 Dim = TargetVocab.begin()->second.size();
401 if (Dim == 0)
403 "Dimension of '" + std::string(Key) +
404 "' section of the vocabulary is zero");
405
406 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
407 [Dim](const std::pair<StringRef, Embedding> &Entry) {
408 return Entry.second.size() == Dim;
409 }))
410 return createStringError(
412 "All vectors in the '" + std::string(Key) +
413 "' section of the vocabulary are not of the same dimension");
414
415 return Error::success();
416}
417
418// FIXME: Make this optional. We can avoid file reads
419// by auto-generating a default vocabulary during the build time.
420Error IR2VecVocabAnalysis::readVocabulary() {
421 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
422 if (!BufOrError)
423 return createFileError(VocabFile, BufOrError.getError());
424
425 auto Content = BufOrError.get()->getBuffer();
426
427 Expected<json::Value> ParsedVocabValue = json::parse(Content);
428 if (!ParsedVocabValue)
429 return ParsedVocabValue.takeError();
430
431 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
432 if (auto Err =
433 parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
434 return Err;
435
436 if (auto Err =
437 parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
438 return Err;
439
440 if (auto Err =
441 parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
442 return Err;
443
444 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
446 "Vocabulary sections have different dimensions");
447
448 return Error::success();
449}
450
451void IR2VecVocabAnalysis::generateNumMappedVocab() {
452
453 // Helper for handling missing entities in the vocabulary.
454 // Currently, we use a zero vector. In the future, we will throw an error to
455 // ensure that *all* known entities are present in the vocabulary.
456 auto handleMissingEntity = [](const std::string &Val) {
457 LLVM_DEBUG(errs() << Val
458 << " is not in vocabulary, using zero vector; This "
459 "would result in an error in future.\n");
460 ++VocabMissCounter;
461 };
462
463 unsigned Dim = OpcVocab.begin()->second.size();
464 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
465
466 // Handle Opcodes
467 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
468 Embedding(Dim));
469 NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
470 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
471 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
472 auto It = OpcVocab.find(VocabKey.str());
473 if (It != OpcVocab.end())
474 NumericOpcodeEmbeddings[Opcode] = It->second;
475 else
476 handleMissingEntity(VocabKey.str());
477 }
478 Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
479 NumericOpcodeEmbeddings.end());
480
481 // Handle Types - only canonical types are present in vocabulary
482 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
483 Embedding(Dim));
484 NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
485 for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
486 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
487 static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
488 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
489 NumericTypeEmbeddings[CTypeID] = It->second;
490 continue;
491 }
492 handleMissingEntity(VocabKey.str());
493 }
494 Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
495 NumericTypeEmbeddings.end());
496
497 // Handle Arguments/Operands
498 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
499 Embedding(Dim));
500 NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
501 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
503 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
504 auto It = ArgVocab.find(VocabKey.str());
505 if (It != ArgVocab.end()) {
506 NumericArgEmbeddings[OpKind] = It->second;
507 continue;
508 }
509 handleMissingEntity(VocabKey.str());
510 }
511 Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
512 NumericArgEmbeddings.end());
513}
514
516 : Vocab(Vocab) {}
517
519 : Vocab(std::move(Vocab)) {}
520
521void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
522 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
523 Ctx.emitError("Error reading vocabulary: " + EI.message());
524 });
525}
526
529 auto Ctx = &M.getContext();
530 // If vocabulary is already populated by the constructor, use it.
531 if (!Vocab.empty())
532 return Vocabulary(std::move(Vocab));
533
534 // Otherwise, try to read from the vocabulary file.
535 if (VocabFile.empty()) {
536 // FIXME: Use default vocabulary
537 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
538 "set it using --ir2vec-vocab-path");
539 return Vocabulary(); // Return invalid result
540 }
541 if (auto Err = readVocabulary()) {
542 emitError(std::move(Err), *Ctx);
543 return Vocabulary();
544 }
545
546 // Scale the vocabulary sections based on the provided weights
547 auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {
548 for (auto &Entry : Vocab)
549 Entry.second *= Weight;
550 };
551 scaleVocabSection(OpcVocab, OpcWeight);
552 scaleVocabSection(TypeVocab, TypeWeight);
553 scaleVocabSection(ArgVocab, ArgWeight);
554
555 // Generate the numeric lookup vocabulary
556 generateNumMappedVocab();
557
558 return Vocabulary(std::move(Vocab));
559}
560
561//===----------------------------------------------------------------------===//
562// Printer Passes
563//===----------------------------------------------------------------------===//
564
567 auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
568 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
569
570 for (Function &F : M) {
572 if (!Emb) {
573 OS << "Error creating IR2Vec embeddings \n";
574 continue;
575 }
576
577 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
578 OS << "Function vector: ";
579 Emb->getFunctionVector().print(OS);
580
581 OS << "Basic block vectors:\n";
582 const auto &BBMap = Emb->getBBVecMap();
583 for (const BasicBlock &BB : F) {
584 auto It = BBMap.find(&BB);
585 if (It != BBMap.end()) {
586 OS << "Basic block: " << BB.getName() << ":\n";
587 It->second.print(OS);
588 }
589 }
590
591 OS << "Instruction vectors:\n";
592 const auto &InstMap = Emb->getInstVecMap();
593 for (const BasicBlock &BB : F) {
594 for (const Instruction &I : BB) {
595 auto It = InstMap.find(&I);
596 if (It != InstMap.end()) {
597 OS << "Instruction: ";
598 I.print(OS);
599 It->second.print(OS);
600 }
601 }
602 }
603 }
604 return PreservedAnalyses::all();
605}
606
609 auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
610 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
611
612 // Print each entry
613 unsigned Pos = 0;
614 for (const auto &Entry : IR2VecVocabulary) {
615 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
616 Entry.print(OS);
617 }
618 return PreservedAnalyses::all();
619}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build a generic depth-first graph iterator.
#define _
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define P(N)
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:167
#define LLVM_DEBUG(...)
Definition Debug.h:119
LLVM Basic Block Representation.
Definition BasicBlock.h:62
LLVM_ABI iterator_range< filter_iterator< BasicBlock::const_iterator, std::function< bool(const Instruction &)> > > instructionsWithoutDebug(bool SkipPseudoOp=true) const
Return a const iterator range over the instructions in the block, skipping any debug instructions.
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
iterator end()
Definition DenseMap.h:81
Base class for error info classes.
Definition Error.h:44
virtual std::string message() const
Return the error message as a string.
Definition Error.h:52
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:565
This analysis provides the vocabulary for IR2Vec.
Definition IR2Vec.h:420
ir2vec::Vocabulary Result
Definition IR2Vec.h:437
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:528
static LLVM_ABI AnalysisKey Key
Definition IR2Vec.h:433
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:607
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:233
TypeID
Definitions of all of the base types for the Type system.
Definition Type.h:54
LLVM Value Representation.
Definition Value.h:75
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
LLVM_ABI const Embedding & getBBVector(const BasicBlock &BB) const
Returns the embedding for a given basic block in the function F if it has been computed.
Definition IR2Vec.cpp:183
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
Definition IR2Vec.cpp:160
BBEmbeddingsMap BBVecMap
Definition IR2Vec.h:352
LLVM_ABI const BBEmbeddingsMap & getBBVecMap() const
Returns a map containing basic block and the corresponding embeddings for the function F if it has be...
Definition IR2Vec.cpp:177
const Vocabulary & Vocab
Definition IR2Vec.h:340
void computeEmbeddings() const
Function to compute embeddings.
Definition IR2Vec.cpp:198
const float TypeWeight
Definition IR2Vec.h:347
LLVM_ABI const InstEmbeddingsMap & getInstVecMap() const
Returns a map containing instructions and the corresponding embeddings for the function F if it has b...
Definition IR2Vec.cpp:171
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...
Definition IR2Vec.h:347
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Definition IR2Vec.h:343
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
Definition IR2Vec.cpp:155
const float ArgWeight
Definition IR2Vec.h:347
Embedding FuncVector
Definition IR2Vec.h:351
LLVM_ABI const Embedding & getFunctionVector() const
Computes and returns the embedding for the current function.
Definition IR2Vec.cpp:191
InstEmbeddingsMap InstVecMap
Definition IR2Vec.h:353
const Function & F
Definition IR2Vec.h:339
Class for storing and accessing the IR2Vec vocabulary.
Definition IR2Vec.h:163
static LLVM_ABI unsigned getSlotIndex(unsigned Opcode)
Functions to return the slot index or position of a given Opcode, TypeID, or OperandKind in the vocab...
Definition IR2Vec.cpp:268
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
Definition IR2Vec.cpp:356
LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
Definition IR2Vec.cpp:284
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
Definition IR2Vec.cpp:331
friend class llvm::IR2VecVocabAnalysis
Definition IR2Vec.h:164
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Definition IR2Vec.cpp:341
static constexpr unsigned MaxCanonicalTypeIDs
Definition IR2Vec.h:206
static LLVM_ABI VocabVector createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition IR2Vec.cpp:362
static constexpr unsigned MaxOperandKinds
Definition IR2Vec.h:208
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
Definition IR2Vec.h:192
static constexpr unsigned MaxTypeIDs
Definition IR2Vec.h:205
static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
Definition IR2Vec.cpp:320
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
Definition IR2Vec.cpp:296
LLVM_ABI bool isValid() const
Definition IR2Vec.h:214
LLVM_ABI unsigned getDimension() const
Definition IR2Vec.cpp:263
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
Definition IR2Vec.h:175
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Definition IR2Vec.cpp:324
An Object is a JSON object, which maps strings to heterogeneous JSON values.
Definition JSON.h:98
LLVM_ABI Value * get(StringRef K)
Definition JSON.cpp:30
The root is the trivial Path to the root value.
Definition JSON.h:713
A "cursor" marking a position within a Value.
Definition JSON.h:666
A Value is an JSON value of unknown type.
Definition JSON.h:290
const json::Object * getAsObject() const
Definition JSON.h:464
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
Definition IR2Vec.h:143
static cl::opt< std::string > VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory))
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
Definition IR2Vec.h:144
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
bool fromJSON(const Value &E, std::string &Out, Path P)
Definition JSON.h:742
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Definition Error.h:990
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ illegal_byte_sequence
Definition Errc.h:52
@ invalid_argument
Definition Errc.h:56
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
Definition IR2Vec.h:69
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
Definition Format.h:126
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:85
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
Definition IR2Vec.cpp:131
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
Definition IR2Vec.cpp:86
LLVM_ABI Embedding operator-(const Embedding &RHS) const
Definition IR2Vec.cpp:106
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
Definition IR2Vec.cpp:99
LLVM_ABI Embedding operator*(double Factor) const
Definition IR2Vec.cpp:118
size_t size() const
Definition IR2Vec.h:98
LLVM_ABI Embedding & operator*=(double Factor)
Definition IR2Vec.cpp:112
LLVM_ABI Embedding operator+(const Embedding &RHS) const
Definition IR2Vec.cpp:93
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Definition IR2Vec.cpp:124
LLVM_ABI void print(raw_ostream &OS) const
Definition IR2Vec.cpp:144