You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

363 lines
9.8 KiB

use std::sync::{Arc, Mutex};
use lazy_static::lazy_static;
use rusqlite::{params, Connection, Transaction};
use serenity::model::id::{ChannelId, GuildId, MessageId, RoleId, UserId};
use crate::error::{CrapBotError, Result};
lazy_static! {
    // TODO remove lazy_static and work with dependency injection
    /// The single SQLite connection shared by every function in this module.
    ///
    /// It is built by `init_db` the first time `DB` is dereferenced; all access
    /// goes through the `Mutex`, so queries and transactions are serialized.
    static ref DB: Arc<Mutex<Connection>> = {
        let connection = init_db();
        Arc::new(Mutex::new(connection))
    };
}
/// Opens `cache.db` (relative to the working directory) and runs the bundled
/// `sql/initialize_db.sql` setup script on it.
///
/// # Panics
/// Panics if the database cannot be opened or the setup script fails.
fn init_db() -> Connection {
    const SETUP_SQL: &str = include_str!("sql/initialize_db.sql");
    let connection = Connection::open("cache.db").expect("could not open database");
    connection
        .execute_batch(SETUP_SQL)
        .expect("Failed to setup database");
    connection
}
/// Runs `callback` inside one transaction on the shared connection and commits
/// it when the callback succeeds, returning the callback's value.
///
/// If either the callback or the commit fails, the error is logged to stderr and
/// returned. An uncommitted transaction is rolled back when it is dropped.
///
/// # Panics
/// Panics if the `DB` mutex is poisoned.
#[inline]
fn run_transaction<T>(callback: &dyn Fn(&mut Transaction) -> Result<T>) -> Result<T> {
    let mut guard = DB.lock().expect("could not acquire database lock");
    // Dropping `tx` without committing rolls everything back.
    let mut tx = guard.transaction()?;

    let value = callback(&mut tx).map_err(|why| {
        eprintln!("transaction failed: {:?}", why);
        why
    })?;

    tx.commit().map_err(|why| {
        eprintln!("failed to commit transaction: {:?}", why);
        why
    })?;

    Ok(value)
}
fn user_can_delete_channel(guild: GuildId, channel: ChannelId, user: UserId) -> Result<()> {
let db = DB.lock().expect("could not acquire database lock");
let mut check_stmt =
db.prepare_cached("SELECT creator_id FROM created_channels where guild_id=? AND id=?")?;
let creator_id = check_stmt.query_row([guild.0, channel.0], |r| Ok(UserId(r.get(0)?)))?;
if creator_id != user {
return Err(CrapBotError::InsufficientPermissions);
}
Ok(())
}
/// Registers `role` as an admin role of `guild` (one row in `admins`).
pub fn add_admin(guild: GuildId, role: RoleId) -> Result<()> {
    run_transaction(&|tx| {
        tx.prepare_cached("INSERT INTO admins (role_id, guild_id) VALUES (?, ?)")?
            .execute([role.0, guild.0])?;
        Ok(())
    })
}
/// Removes `role` from the admin roles of `guild`.
///
/// Deleting a role that is not registered is not an error.
pub fn delete_admin(guild: GuildId, role: RoleId) -> Result<()> {
    run_transaction(&|tx| {
        tx.prepare_cached("DELETE FROM admins WHERE guild_id=? AND role_id=?")?
            .execute([guild.0, role.0])?;
        Ok(())
    })
}
pub fn get_admins_for_guild(guild: GuildId) -> Result<Vec<RoleId>> {
let db = DB.lock().expect("could not acquire database lock");
let mut select_stmt = db.prepare_cached("SELECT role_id FROM admins WHERE guild_id=?")?;
let roles = select_stmt
.query_map([guild.0], |row| Ok(RoleId(row.get(0)?)))?
.map(|r| r.unwrap())
.collect();
Ok(roles)
}
/// Records each entry of `channels` as created in `guild` by `user`, with
/// `channel` as its parent.
///
/// All inserts run in one transaction, so a failing insert leaves none of them stored.
pub fn add_created_channels(
    guild: GuildId,
    channel: ChannelId,
    user: UserId,
    channels: Vec<ChannelId>,
) -> Result<()> {
    run_transaction(&|tx| {
        let mut insert = tx.prepare_cached(
            "INSERT INTO created_channels (id, guild_id, creator_id, parent_id) VALUES (?, ?, ?, ?)",
        )?;
        channels.iter().try_for_each(|created| {
            insert
                .execute([created.0, guild.0, user.0, channel.0])
                .map(|_| ())
        })?;
        Ok(())
    })
}
pub fn get_channels_for_user(guild: GuildId, user: UserId) -> Result<Vec<ChannelId>> {
let db = DB.lock().expect("could not acquire database lock");
let mut select_stmt = db.prepare_cached(
"SELECT DISTINCT parent_id FROM created_channels WHERE guild_id=? AND creator_id=?",
)?;
let channels = select_stmt
.query_map([guild.0, user.0], |row| Ok(ChannelId(row.get(0)?)))?
.map(|r| r.unwrap())
.collect();
Ok(channels)
}
pub fn get_channels_to_delete(
guild: GuildId,
channel: ChannelId,
user: Option<UserId>,
) -> Result<Vec<ChannelId>> {
if let Some(user) = user {
user_can_delete_channel(guild, channel, user)?;
}
let db = DB.lock().expect("could not acquire database lock");
let mut select_stmt =
db.prepare_cached("SELECT id FROM created_channels WHERE guild_id=? AND parent_id=?")?;
let channels: Vec<ChannelId> = select_stmt
.query_map([guild.0, channel.0], |row| Ok(ChannelId(row.get(0)?)))?
.map(|r| r.unwrap())
.collect();
if channels.is_empty() {
return Err(CrapBotError::NoEntry);
}
Ok(channels)
}
/// Forgets every channel recorded with `channel` as its parent in `guild`.
///
/// When `user` is `Some`, that user must be the recorded creator of `channel`.
/// `None` skips the ownership check. Deleting when nothing is recorded is not an error.
pub fn delete_created_channels(
    guild: GuildId,
    channel: ChannelId,
    user: Option<UserId>,
) -> Result<()> {
    user.map(|requester| user_can_delete_channel(guild, channel, requester))
        .transpose()?;

    run_transaction(&|tx| {
        tx.prepare_cached("DELETE FROM created_channels WHERE guild_id=? AND parent_id=?")?
            .execute([guild.0, channel.0])?;
        Ok(())
    })
}
/// Associates reacting with `emoji` on `message` in `guild` with `role`.
pub fn add_emoji_role(
    guild: GuildId,
    message: MessageId,
    emoji: String,
    role: RoleId,
) -> Result<()> {
    const INSERT_SQL: &str =
        "INSERT INTO emoji_roles (guild_id, message_id, emoji, role_id) VALUES (?, ?, ?, ?)";
    run_transaction(&|tx| {
        let mut insert = tx.prepare_cached(INSERT_SQL)?;
        insert.execute(params![guild.0, message.0, emoji, role.0])?;
        Ok(())
    })
}
pub fn get_role_for_emoji(guild: GuildId, message: MessageId, emoji: String) -> Result<RoleId> {
let db = DB.lock().expect("could not acquire database lock");
let mut select_stmt = db.prepare_cached(
"SELECT role_id FROM emoji_roles WHERE guild_id=? AND message_id=? AND emoji=?",
)?;
let role = select_stmt.query_row(params![guild.0, message.0, emoji], |row| {
Ok(RoleId(row.get(0)?))
})?;
Ok(role)
}
/// Removes the association of `emoji` on `message` in `guild`, if any.
///
/// Deleting an association that does not exist is not an error.
pub fn delete_emoji_role(guild: GuildId, message: MessageId, emoji: String) -> Result<()> {
    const DELETE_SQL: &str =
        "DELETE FROM emoji_roles WHERE guild_id=? AND message_id=? AND emoji=?";
    run_transaction(&|tx| {
        tx.prepare_cached(DELETE_SQL)?
            .execute(params![guild.0, message.0, emoji])?;
        Ok(())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    lazy_static::lazy_static! {
        /// Serializes the tests below. They all go through the shared on-disk
        /// `cache.db` behind `DB` and reuse the same ids. If the parallel test
        /// harness interleaves them, they observe each other's rows.
        static ref SERIAL: Mutex<()> = Mutex::new(());
    }

    /// Takes the test lock, recovering it if an earlier test panicked while holding it.
    fn serial() -> MutexGuard<'static, ()> {
        SERIAL.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Removes every fixture row the tests below create. This way a previous run
    /// that failed before its cleanup step cannot break this one. Each delete is a
    /// no-op when the row is absent.
    fn reset_fixtures() {
        delete_admin(GuildId(42), RoleId(21)).unwrap();
        delete_created_channels(GuildId(42), ChannelId(21), None).unwrap();
        delete_emoji_role(GuildId(42), MessageId(21), "😂".into()).unwrap();
    }

    #[test]
    fn add_test_admins() {
        let _serial = serial();
        // Role "Botmasters" on test server. This row is not cleaned up afterwards.
        // NOTE(review): if `admins` has a uniqueness constraint, re-running this
        // test will fail on the duplicate insert — confirm against initialize_db.sql.
        add_admin(GuildId(783019282404737086), RoleId(783020797479878656)).unwrap();
    }

    #[test]
    fn test_get_admins() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_admin(GuildId(42), RoleId(21)).unwrap();
        assert_eq!(vec![RoleId(21)], get_admins_for_guild(GuildId(42)).unwrap());
        // cleanup
        delete_admin(GuildId(42), RoleId(21)).unwrap();
    }

    #[test]
    fn test_add_admin() {
        let _serial = serial();
        reset_fixtures();
        add_admin(GuildId(42), RoleId(21)).unwrap();
        assert!(get_admins_for_guild(GuildId(42))
            .unwrap()
            .contains(&RoleId(21)));
        // cleanup
        delete_admin(GuildId(42), RoleId(21)).unwrap();
    }

    #[test]
    fn test_delete_admin() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_admin(GuildId(42), RoleId(21)).unwrap();
        delete_admin(GuildId(42), RoleId(21)).unwrap();
        assert!(!get_admins_for_guild(GuildId(42))
            .unwrap()
            .contains(&RoleId(21)));
    }

    #[test]
    fn test_add_created_channels() {
        let _serial = serial();
        reset_fixtures();
        add_created_channels(
            GuildId(42),
            ChannelId(21),
            UserId(11),
            vec![ChannelId(21), ChannelId(22), ChannelId(23)],
        )
        .unwrap();
        // cleanup
        delete_created_channels(GuildId(42), ChannelId(21), None).unwrap();
    }

    #[test]
    fn test_delete_created_channels() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_created_channels(
            GuildId(42),
            ChannelId(21),
            UserId(11),
            vec![ChannelId(21), ChannelId(22), ChannelId(23)],
        )
        .unwrap();
        delete_created_channels(GuildId(42), ChannelId(21), None).unwrap();
        assert!(matches!(
            get_channels_to_delete(GuildId(42), ChannelId(21), None),
            Err(CrapBotError::NoEntry)
        ));
    }

    #[test]
    fn test_get_channels_to_delete() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_created_channels(
            GuildId(42),
            ChannelId(21),
            UserId(11),
            vec![ChannelId(21), ChannelId(22), ChannelId(23)],
        )
        .unwrap();
        let mut channels =
            get_channels_to_delete(GuildId(42), ChannelId(21), Some(UserId(11))).unwrap();
        // The query has no ORDER BY, so compare independent of row order.
        channels.sort_by_key(|c| c.0);
        assert_eq!(vec![ChannelId(21), ChannelId(22), ChannelId(23)], channels);
        // cleanup
        delete_created_channels(GuildId(42), ChannelId(21), None).unwrap();
    }

    #[test]
    fn test_get_channels_for_user() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_created_channels(
            GuildId(42),
            ChannelId(21),
            UserId(11),
            vec![ChannelId(21), ChannelId(22), ChannelId(23)],
        )
        .unwrap();
        assert_eq!(
            vec![ChannelId(21)],
            get_channels_for_user(GuildId(42), UserId(11)).unwrap()
        );
        // cleanup
        delete_created_channels(GuildId(42), ChannelId(21), None).unwrap();
    }

    #[test]
    fn test_add_emoji_role() {
        let _serial = serial();
        reset_fixtures();
        add_emoji_role(GuildId(42), MessageId(21), "😂".into(), RoleId(5)).unwrap();
        // cleanup
        delete_emoji_role(GuildId(42), MessageId(21), "😂".into()).unwrap();
    }

    #[test]
    fn test_get_role_for_emoji() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_emoji_role(GuildId(42), MessageId(21), "😂".into(), RoleId(5)).unwrap();
        assert_eq!(
            RoleId(5),
            get_role_for_emoji(GuildId(42), MessageId(21), "😂".into()).unwrap()
        );
        // cleanup
        delete_emoji_role(GuildId(42), MessageId(21), "😂".into()).unwrap();
    }

    #[test]
    fn test_delete_emoji_role() {
        let _serial = serial();
        reset_fixtures();
        // setup
        add_emoji_role(GuildId(42), MessageId(21), "😂".into(), RoleId(5)).unwrap();
        delete_emoji_role(GuildId(42), MessageId(21), "😂".into()).unwrap();
        assert!(matches!(
            get_role_for_emoji(GuildId(42), MessageId(21), "😂".into()),
            Err(CrapBotError::NoEntry)
        ));
    }
}