diff --git a/CMakeLists.txt b/CMakeLists.txt index 241d93e..8dc103e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.30 FATAL_ERROR) PROJECT(HOOLANG) SET (CMAKE_CXX_STANDARD 17) -SET (ANTLR_JAR /usr/local/lib/antlr4-4.13.2-complete.jar) +SET (ANTLR_JAR "${CMAKE_SOURCE_DIR}/antlr4-4.13.2-complete.jar") SET (ANTLR_INCLUDE_DIR /usr/local/include/antlr4-runtime/) SET (ANTLR_GENERATED_DIR "${CMAKE_BINARY_DIR}/antlr4/generated") SET (GRAMMAR_FILE "${CMAKE_SOURCE_DIR}/Hoo.g4") diff --git a/src/Hoo.cpp b/src/Hoo.cpp index 0373b85..44813e9 100644 --- a/src/Hoo.cpp +++ b/src/Hoo.cpp @@ -1,3 +1,30 @@ +#include +#include +#include +#include "Compiler.hpp" + int main(int argc, char* argv[]) { + if (argc != 2) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + std::ifstream file(argv[1]); + if (!file.is_open()) { + std::cerr << "Error: could not open file " << argv[1] << std::endl; + return 1; + } + + std::stringstream buffer; + buffer << file.rdbuf(); + std::string input = buffer.str(); + + try { + Compiler compiler(input, "hoo_module"); + compiler.compile(); + } catch (...) { + return 1; + } + return 0; } \ No newline at end of file diff --git a/src/Visitor.cpp b/src/Visitor.cpp index 9bac1cd..f5f7e67 100644 --- a/src/Visitor.cpp +++ b/src/Visitor.cpp @@ -7,6 +7,7 @@ #include #include #include +#include Visitor::Visitor(const std::string &moduleName) : _moduleName(moduleName), _context(std::make_shared()), @@ -15,7 +16,7 @@ Visitor::Visitor(const std::string &moduleName) : _moduleName(moduleName), { } -std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) +llvm::Value *Visitor::visitLiteral(HooParser::LiteralContext *ctx) { auto value = ctx->INTEGER_LITERAL(); #ifndef NDEBUG @@ -33,7 +34,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) auto constantValue = llvm::dyn_cast(decimalConstant)->getValue().getSExtValue(); assert(constantValue == decimalValue); #endif - return std::any{Node(NODE_LITERAL, DATATYPE_INTEGER, decimalConstant)}; + return decimalConstant; } value = ctx->DOUBLE_LITERAL(); @@ -43,7 +44,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) auto doubleValue = std::stod(doubleText); llvm::Constant *doubleConstant = llvm::ConstantFP::get( llvm::Type::getDoubleTy(*_context), doubleValue); - return std::any{Node(NODE_LITERAL, DATATYPE_DOUBLE, doubleConstant)}; + return doubleConstant; } value = ctx->BOOL_LITERAL(); @@ -53,7 +54,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) auto boolValue = boolText == "true" ? 1 : 0; llvm::Type *boolType = llvm::Type::getInt1Ty(*_context); llvm::Constant *boolConstant = llvm::ConstantInt::get(boolType, boolValue, true); - return std::any{Node(NODE_LITERAL, DATATYPE_BOOL, boolConstant)}; + return boolConstant; } value = ctx->CHAR_LITERAL(); @@ -73,7 +74,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) llvm::ConstantInt::get(byteType, charValue.bytes[0]), llvm::ConstantInt::get(byteType, charValue.bytes[1]), })}); - return std::any{Node(NODE_LITERAL, DATATYPE_CHAR, charConstant)}; + return charConstant; } value = ctx->STRING_LITERAL(); @@ -81,7 +82,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) { auto stringText = value->getText(); auto stringConstant = llvm::ConstantDataArray::getString(*_context, stringText); - return std::any{Node(NODE_LITERAL, DATATYPE_STRING, stringConstant)}; + return stringConstant; } auto line_no = ctx->getStart()->getLine(); @@ -93,7 +94,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx) throw ParseErrorException(_moduleName, line_no, char_pos, message); } -std::any Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) +llvm::Value *Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) { auto literal = ctx->literal(); if (literal != nullptr) @@ -101,17 +102,17 @@ std::any Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) auto node = visitLiteral(literal); return node; } - return std::any(); + return nullptr; } -std::any Visitor::visitNestedExpression(HooParser::NestedExpressionContext *ctx) +llvm::Value *Visitor::visitNestedExpression(HooParser::NestedExpressionContext *ctx) { auto expr_ctx = ctx->expression(); auto node = visit(expr_ctx); return node; } -std::any Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) +llvm::Value *Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) { auto expressionCtx = ctx->expression(); if (expressionCtx != nullptr) @@ -119,10 +120,10 @@ std::any Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext auto node = visit(expressionCtx); return node; } - return std::any(); + return nullptr; } -std::any Visitor::visitStatement(HooParser::StatementContext *ctx) +llvm::Value *Visitor::visitStatement(HooParser::StatementContext *ctx) { auto expr_stmt_ctx = ctx->expressionStatement(); if (expr_stmt_ctx != nullptr) @@ -130,16 +131,130 @@ std::any Visitor::visitStatement(HooParser::StatementContext *ctx) auto node = visitExpressionStatement(expr_stmt_ctx); return node; } - return std::any(); + return nullptr; } -std::any Visitor::visitUnit(HooParser::UnitContext *ctx) +llvm::Value *Visitor::visitUnit(HooParser::UnitContext *ctx) { + llvm::FunctionType* funcType = llvm::FunctionType::get(_builder->getInt32Ty(), false); + llvm::Function* mainFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "main", _module.get()); + llvm::BasicBlock* entry = llvm::BasicBlock::Create(*_context, "entry", mainFunc); + _builder->SetInsertPoint(entry); + auto stmt_ctx = ctx->statement(); if (stmt_ctx != nullptr) { - auto result = visitStatement(stmt_ctx); - return result; + visitStatement(stmt_ctx); + } + + _builder->CreateRet(llvm::ConstantInt::get(*_context, llvm::APInt(32, 0))); + + _module->print(llvm::outs(), nullptr); + + return nullptr; +} + +llvm::Value *Visitor::visitMultiplicationExpression(HooParser::MultiplicationExpressionContext *ctx) +{ + auto left = visit(ctx->expression(0)); + auto right = visit(ctx->expression(1)); + if (left->getType() != right->getType()) + { + // TODO: handle type mismatch + return nullptr; + } + if (left->getType()->isIntegerTy()) + { + return _builder->CreateMul(left, right, "multmp"); + } + if (left->getType()->isDoubleTy()) + { + return _builder->CreateFMul(left, right, "multmp"); } return nullptr; +} + +llvm::Value *Visitor::visitDivisionExpression(HooParser::DivisionExpressionContext *ctx) +{ + auto left = visit(ctx->expression(0)); + auto right = visit(ctx->expression(1)); + if (left->getType() != right->getType()) + { + // TODO: handle type mismatch + return nullptr; + } + if (left->getType()->isIntegerTy()) + { + return _builder->CreateSDiv(left, right, "divtmp"); + } + if (left->getType()->isDoubleTy()) + { + return _builder->CreateFDiv(left, right, "divtmp"); + } + return nullptr; +} + +llvm::Value *Visitor::visitReminderExpression(HooParser::ReminderExpressionContext *ctx) +{ + auto left = visit(ctx->expression(0)); + auto right = visit(ctx->expression(1)); + if (left->getType() != right->getType()) + { + // TODO: handle type mismatch + return nullptr; + } + if (left->getType()->isIntegerTy()) + { + return _builder->CreateSRem(left, right, "remtmp"); + } + if (left->getType()->isDoubleTy()) + { + return _builder->CreateFRem(left, right, "remtmp"); + } + return nullptr; +} + +llvm::Value *Visitor::visitAdditiveExpression(HooParser::AdditiveExpressionContext *ctx) +{ + auto left = visit(ctx->expression(0)); + auto right = visit(ctx->expression(1)); + if (left->getType() != right->getType()) + { + // TODO: handle type mismatch + return nullptr; + } + if (left->getType()->isIntegerTy()) + { + return _builder->CreateAdd(left, right, "addtmp"); + } + if (left->getType()->isDoubleTy()) + { + return _builder->CreateFAdd(left, right, "addtmp"); + } + return nullptr; +} + +llvm::Value *Visitor::visitSubtractExpression(HooParser::SubtractExpressionContext *ctx) +{ + auto left = visit(ctx->expression(0)); + auto right = visit(ctx->expression(1)); + if (left->getType() != right->getType()) + { + // TODO: handle type mismatch + return nullptr; + } + if (left->getType()->isIntegerTy()) + { + return _builder->CreateSub(left, right, "subtmp"); + } + if (left->getType()->isDoubleTy()) + { + return _builder->CreateFSub(left, right, "subtmp"); + } + return nullptr; +} + +llvm::Value *Visitor::visitPrimaryExpression(HooParser::PrimaryExpressionContext *ctx) +{ + return visit(ctx->primary()); } \ No newline at end of file diff --git a/src/Visitor.hpp b/src/Visitor.hpp index b327400..3670c12 100644 --- a/src/Visitor.hpp +++ b/src/Visitor.hpp @@ -20,12 +20,18 @@ public: Visitor(const std::string &moduleName); public: - std::any visitLiteral(HooParser::LiteralContext *ctx) override; - std::any visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) override; - std::any visitNestedExpression(HooParser::NestedExpressionContext *ctx) override; - std::any visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) override; - std::any visitStatement(HooParser::StatementContext *ctx) override; - std::any visitUnit(HooParser::UnitContext *ctx) override; + llvm::Value *visitLiteral(HooParser::LiteralContext *ctx) override; + llvm::Value *visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) override; + llvm::Value *visitNestedExpression(HooParser::NestedExpressionContext *ctx) override; + llvm::Value *visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) override; + llvm::Value *visitStatement(HooParser::StatementContext *ctx) override; + llvm::Value *visitUnit(HooParser::UnitContext *ctx) override; + llvm::Value *visitMultiplicationExpression(HooParser::MultiplicationExpressionContext *ctx) override; + llvm::Value *visitDivisionExpression(HooParser::DivisionExpressionContext *ctx) override; + llvm::Value *visitReminderExpression(HooParser::ReminderExpressionContext *ctx) override; + llvm::Value *visitAdditiveExpression(HooParser::AdditiveExpressionContext *ctx) override; + llvm::Value *visitSubtractExpression(HooParser::SubtractExpressionContext *ctx) override; + llvm::Value *visitPrimaryExpression(HooParser::PrimaryExpressionContext *ctx) override; public: std::shared_ptr getContext() const { return _context; }