Implemented arithmetic operators and file input

This commit is contained in:
Benoy Bose 2026-01-27 16:07:40 +05:30
parent e93945f63a
commit e7a2a1940a
4 changed files with 171 additions and 23 deletions

View File

@ -2,7 +2,7 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.30 FATAL_ERROR)
PROJECT(HOOLANG)
SET (CMAKE_CXX_STANDARD 17)
SET (ANTLR_JAR /usr/local/lib/antlr4-4.13.2-complete.jar)
SET (ANTLR_JAR "${CMAKE_SOURCE_DIR}/antlr4-4.13.2-complete.jar")
SET (ANTLR_INCLUDE_DIR /usr/local/include/antlr4-runtime/)
SET (ANTLR_GENERATED_DIR "${CMAKE_BINARY_DIR}/antlr4/generated")
SET (GRAMMAR_FILE "${CMAKE_SOURCE_DIR}/Hoo.g4")

View File

@ -1,3 +1,30 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include "Compiler.hpp"
int main(int argc, char* argv[]) {
if (argc != 2) {
std::cerr << "Usage: " << argv[0] << " <file>" << std::endl;
return 1;
}
std::ifstream file(argv[1]);
if (!file.is_open()) {
std::cerr << "Error: could not open file " << argv[1] << std::endl;
return 1;
}
std::stringstream buffer;
buffer << file.rdbuf();
std::string input = buffer.str();
try {
Compiler compiler(input, "hoo_module");
compiler.compile();
} catch (...) {
return 1;
}
return 0;
}

View File

@ -7,6 +7,7 @@
#include <llvm/IR/Constants.h>
#include <llvm/IR/Type.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/IR/Instructions.h>
Visitor::Visitor(const std::string &moduleName) : _moduleName(moduleName),
_context(std::make_shared<llvm::LLVMContext>()),
@ -15,7 +16,7 @@ Visitor::Visitor(const std::string &moduleName) : _moduleName(moduleName),
{
}
std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
llvm::Value *Visitor::visitLiteral(HooParser::LiteralContext *ctx)
{
auto value = ctx->INTEGER_LITERAL();
#ifndef NDEBUG
@ -33,7 +34,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
auto constantValue = llvm::dyn_cast<llvm::ConstantInt>(decimalConstant)->getValue().getSExtValue();
assert(constantValue == decimalValue);
#endif
return std::any{Node(NODE_LITERAL, DATATYPE_INTEGER, decimalConstant)};
return decimalConstant;
}
value = ctx->DOUBLE_LITERAL();
@ -43,7 +44,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
auto doubleValue = std::stod(doubleText);
llvm::Constant *doubleConstant = llvm::ConstantFP::get(
llvm::Type::getDoubleTy(*_context), doubleValue);
return std::any{Node(NODE_LITERAL, DATATYPE_DOUBLE, doubleConstant)};
return doubleConstant;
}
value = ctx->BOOL_LITERAL();
@ -53,7 +54,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
auto boolValue = boolText == "true" ? 1 : 0;
llvm::Type *boolType = llvm::Type::getInt1Ty(*_context);
llvm::Constant *boolConstant = llvm::ConstantInt::get(boolType, boolValue, true);
return std::any{Node(NODE_LITERAL, DATATYPE_BOOL, boolConstant)};
return boolConstant;
}
value = ctx->CHAR_LITERAL();
@ -73,7 +74,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
llvm::ConstantInt::get(byteType, charValue.bytes[0]),
llvm::ConstantInt::get(byteType, charValue.bytes[1]),
})});
return std::any{Node(NODE_LITERAL, DATATYPE_CHAR, charConstant)};
return charConstant;
}
value = ctx->STRING_LITERAL();
@ -81,7 +82,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
{
auto stringText = value->getText();
auto stringConstant = llvm::ConstantDataArray::getString(*_context, stringText);
return std::any{Node(NODE_LITERAL, DATATYPE_STRING, stringConstant)};
return stringConstant;
}
auto line_no = ctx->getStart()->getLine();
@ -93,7 +94,7 @@ std::any Visitor::visitLiteral(HooParser::LiteralContext *ctx)
throw ParseErrorException(_moduleName, line_no, char_pos, message);
}
std::any Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx)
llvm::Value *Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx)
{
auto literal = ctx->literal();
if (literal != nullptr)
@ -101,17 +102,17 @@ std::any Visitor::visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx)
auto node = visitLiteral(literal);
return node;
}
return std::any();
return nullptr;
}
std::any Visitor::visitNestedExpression(HooParser::NestedExpressionContext *ctx)
llvm::Value *Visitor::visitNestedExpression(HooParser::NestedExpressionContext *ctx)
{
auto expr_ctx = ctx->expression();
auto node = visit(expr_ctx);
return node;
}
std::any Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext *ctx)
llvm::Value *Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext *ctx)
{
auto expressionCtx = ctx->expression();
if (expressionCtx != nullptr)
@ -119,10 +120,10 @@ std::any Visitor::visitExpressionStatement(HooParser::ExpressionStatementContext
auto node = visit(expressionCtx);
return node;
}
return std::any();
return nullptr;
}
std::any Visitor::visitStatement(HooParser::StatementContext *ctx)
llvm::Value *Visitor::visitStatement(HooParser::StatementContext *ctx)
{
auto expr_stmt_ctx = ctx->expressionStatement();
if (expr_stmt_ctx != nullptr)
@ -130,16 +131,130 @@ std::any Visitor::visitStatement(HooParser::StatementContext *ctx)
auto node = visitExpressionStatement(expr_stmt_ctx);
return node;
}
return std::any();
return nullptr;
}
std::any Visitor::visitUnit(HooParser::UnitContext *ctx)
llvm::Value *Visitor::visitUnit(HooParser::UnitContext *ctx)
{
llvm::FunctionType* funcType = llvm::FunctionType::get(_builder->getInt32Ty(), false);
llvm::Function* mainFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "main", _module.get());
llvm::BasicBlock* entry = llvm::BasicBlock::Create(*_context, "entry", mainFunc);
_builder->SetInsertPoint(entry);
auto stmt_ctx = ctx->statement();
if (stmt_ctx != nullptr)
{
auto result = visitStatement(stmt_ctx);
return result;
visitStatement(stmt_ctx);
}
_builder->CreateRet(llvm::ConstantInt::get(*_context, llvm::APInt(32, 0)));
_module->print(llvm::outs(), nullptr);
return nullptr;
}
llvm::Value *Visitor::visitMultiplicationExpression(HooParser::MultiplicationExpressionContext *ctx)
{
auto left = visit(ctx->expression(0));
auto right = visit(ctx->expression(1));
if (left->getType() != right->getType())
{
// TODO: handle type mismatch
return nullptr;
}
if (left->getType()->isIntegerTy())
{
return _builder->CreateMul(left, right, "multmp");
}
if (left->getType()->isDoubleTy())
{
return _builder->CreateFMul(left, right, "multmp");
}
return nullptr;
}
llvm::Value *Visitor::visitDivisionExpression(HooParser::DivisionExpressionContext *ctx)
{
auto left = visit(ctx->expression(0));
auto right = visit(ctx->expression(1));
if (left->getType() != right->getType())
{
// TODO: handle type mismatch
return nullptr;
}
if (left->getType()->isIntegerTy())
{
return _builder->CreateSDiv(left, right, "divtmp");
}
if (left->getType()->isDoubleTy())
{
return _builder->CreateFDiv(left, right, "divtmp");
}
return nullptr;
}
llvm::Value *Visitor::visitReminderExpression(HooParser::ReminderExpressionContext *ctx)
{
auto left = visit(ctx->expression(0));
auto right = visit(ctx->expression(1));
if (left->getType() != right->getType())
{
// TODO: handle type mismatch
return nullptr;
}
if (left->getType()->isIntegerTy())
{
return _builder->CreateSRem(left, right, "remtmp");
}
if (left->getType()->isDoubleTy())
{
return _builder->CreateFRem(left, right, "remtmp");
}
return nullptr;
}
llvm::Value *Visitor::visitAdditiveExpression(HooParser::AdditiveExpressionContext *ctx)
{
auto left = visit(ctx->expression(0));
auto right = visit(ctx->expression(1));
if (left->getType() != right->getType())
{
// TODO: handle type mismatch
return nullptr;
}
if (left->getType()->isIntegerTy())
{
return _builder->CreateAdd(left, right, "addtmp");
}
if (left->getType()->isDoubleTy())
{
return _builder->CreateFAdd(left, right, "addtmp");
}
return nullptr;
}
llvm::Value *Visitor::visitSubtractExpression(HooParser::SubtractExpressionContext *ctx)
{
auto left = visit(ctx->expression(0));
auto right = visit(ctx->expression(1));
if (left->getType() != right->getType())
{
// TODO: handle type mismatch
return nullptr;
}
if (left->getType()->isIntegerTy())
{
return _builder->CreateSub(left, right, "subtmp");
}
if (left->getType()->isDoubleTy())
{
return _builder->CreateFSub(left, right, "subtmp");
}
return nullptr;
}
llvm::Value *Visitor::visitPrimaryExpression(HooParser::PrimaryExpressionContext *ctx)
{
return visit(ctx->primary());
}

View File

@ -20,12 +20,18 @@ public:
Visitor(const std::string &moduleName);
public:
std::any visitLiteral(HooParser::LiteralContext *ctx) override;
std::any visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) override;
std::any visitNestedExpression(HooParser::NestedExpressionContext *ctx) override;
std::any visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) override;
std::any visitStatement(HooParser::StatementContext *ctx) override;
std::any visitUnit(HooParser::UnitContext *ctx) override;
llvm::Value *visitLiteral(HooParser::LiteralContext *ctx) override;
llvm::Value *visitPrimaryLiteral(HooParser::PrimaryLiteralContext *ctx) override;
llvm::Value *visitNestedExpression(HooParser::NestedExpressionContext *ctx) override;
llvm::Value *visitExpressionStatement(HooParser::ExpressionStatementContext *ctx) override;
llvm::Value *visitStatement(HooParser::StatementContext *ctx) override;
llvm::Value *visitUnit(HooParser::UnitContext *ctx) override;
llvm::Value *visitMultiplicationExpression(HooParser::MultiplicationExpressionContext *ctx) override;
llvm::Value *visitDivisionExpression(HooParser::DivisionExpressionContext *ctx) override;
llvm::Value *visitReminderExpression(HooParser::ReminderExpressionContext *ctx) override;
llvm::Value *visitAdditiveExpression(HooParser::AdditiveExpressionContext *ctx) override;
llvm::Value *visitSubtractExpression(HooParser::SubtractExpressionContext *ctx) override;
llvm::Value *visitPrimaryExpression(HooParser::PrimaryExpressionContext *ctx) override;
public:
std::shared_ptr<llvm::LLVMContext> getContext() const { return _context; }