diff --git a/include/libserial/ports.hpp b/include/libserial/ports.hpp index cf710e0..f3bf991 100644 --- a/include/libserial/ports.hpp +++ b/include/libserial/ports.hpp @@ -100,6 +100,15 @@ std::optional findBusPath(uint16_t id) const; std::optional findName(uint16_t id) const; private: +/** + * @brief Safely retrieves the target of a symlink + * + * @param symlink_path The path to the symlink + * @param target A reference to a string that will be populated with the target path + * @return true if the target was successfully retrieved, false otherwise + */ +void getSymlinkTarget(const std::string& symlink_path, std::string& target); + /** * @brief System path where udev creates symlinks for serial devices by ID */ diff --git a/include/libserial/serial_exception.hpp b/include/libserial/serial_exception.hpp index 7327e7d..1929a2b 100644 --- a/include/libserial/serial_exception.hpp +++ b/include/libserial/serial_exception.hpp @@ -87,6 +87,21 @@ explicit IOException(std::string message) } }; // class IOException +/** + * @class ScanPortsException + * @brief Exception class for port scanning errors + * + * The ScanPortsException class is derived from SerialException + * and is used to indicate that an error occurred while scanning + * for serial ports. + */ +class ScanPortsException : public SerialException { +public: +explicit ScanPortsException(std::string message) + : SerialException(std::move(message)) { +} +}; // class ScanPortsException + } // namespace libserial #endif // INCLUDE_LIBSERIAL_SERIAL_EXCEPTION_HPP_ diff --git a/src/ports.cpp b/src/ports.cpp index 37ae45c..8dbd69f 100644 --- a/src/ports.cpp +++ b/src/ports.cpp @@ -43,15 +43,21 @@ uint16_t Ports::scanPorts() { std::string symlink_path = std::string(by_id_dir) + "/" + device_name; // Store the relative path the symlink points to - char target[PATH_MAX] = {0}; - - // Resolve the symlink to get the actual device path relative to /dev - // from the /dev/serial/by-id/ directory (e.g., ../../ttyUSB0) - ssize_t len = readlink(symlink_path.c_str(), target, sizeof(target) - 1); - target[len] = '\0'; - + std::string target; + try { + getSymlinkTarget(symlink_path, target); + } catch (const ScanPortsException& e) { + throw ScanPortsException("Failed to get symlink target for " + symlink_path + ": " + e.what()); + continue; + } + // Resolve the relative path to an absolute path - const char* bname = strrchr(target, '/'); + const char* bname = strrchr(target.c_str(), '/'); + if (!bname) { + // No slash found - malformed symlink, skip + std::cerr << "Warning: Malformed symlink target for " << symlink_path << std::endl; + continue; + } bname++; // Construct the full /dev/ttyXXX path @@ -104,4 +110,33 @@ std::optional Ports::findName(uint16_t id) const { return it->getName(); } +void Ports::getSymlinkTarget(const std::string& symlink_path, std::string& target) { + struct stat stat_buf; + char buffer[PATH_MAX]; + + // Verify it's actually a symlink + if (lstat(symlink_path.c_str(), &stat_buf) == -1) { + throw ScanPortsException("lstat failed for " + symlink_path + ": " + + std::string(strerror(errno))); + } + + if (!S_ISLNK(stat_buf.st_mode)) { + throw ScanPortsException("Not a symlink"); + } + + // Read the symlink target + ssize_t len = readlink(symlink_path.c_str(), buffer, sizeof(buffer) - 1); + if (len < 0) { + throw ScanPortsException("Failed to read symlink: " + + std::string(strerror(errno))); + } else if (len >= static_cast(sizeof(buffer))) { + // Path too long - skip this entry + throw ScanPortsException("Symlink path too long: " + + std::string(strerror(errno))); + } + + buffer[len] = '\0'; + target.assign(buffer, len); +} + } // namespace libserial diff --git a/src/serial.cpp b/src/serial.cpp index 88b6248..04f91a4 100644 --- a/src/serial.cpp +++ b/src/serial.cpp @@ -54,27 +54,29 @@ size_t Serial::read(std::shared_ptr buffer, size_t max_length) { throw IOException("Null pointer passed to read function"); } - if (max_length > kMaxSafeReadSize) { - throw IOException("Read size exceeds maximum safe limit of " + - std::to_string(kMaxSafeReadSize) + " bytes"); - } + // Use the minimum of requested max_length and safe limit + size_t safe_max_length = std::min(max_length, kMaxSafeReadSize); - if (max_length == 0) { + // Check for zero length after applying safe limit TODO: check this point + if (safe_max_length == 0) { buffer->clear(); return 0; } - // Resize the string to accommodate the maximum possible data - buffer->resize(max_length); + // Ensure the buffer has enough capacity + if (buffer->capacity() < safe_max_length) { + buffer->reserve(safe_max_length); + } + + std::vector temp_buffer(safe_max_length); - // Use const_cast to get non-const pointer for read operation - ssize_t bytes_read = ::read(fd_serial_port_, const_cast(buffer->data()), max_length); + ssize_t bytes_read = ::read(fd_serial_port_, temp_buffer.data(), safe_max_length); if (bytes_read < 0) { throw IOException("Error reading from serial port: " + std::string(strerror(errno))); } - // Resize the string to the actual number of bytes read - buffer->resize(static_cast(bytes_read)); + // Safely assign to the string + buffer->assign(temp_buffer.data(), static_cast(bytes_read)); return static_cast(bytes_read); }