summaryrefslogtreecommitdiffstats
path: root/fs/ksmbd/mgmt/user_session.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/ksmbd/mgmt/user_session.c')
-rw-r--r--fs/ksmbd/mgmt/user_session.c95
1 files changed, 53 insertions, 42 deletions
diff --git a/fs/ksmbd/mgmt/user_session.c b/fs/ksmbd/mgmt/user_session.c
index 8d8ffd8c6f19..3fa2139a0b30 100644
--- a/fs/ksmbd/mgmt/user_session.c
+++ b/fs/ksmbd/mgmt/user_session.c
@@ -32,11 +32,13 @@ static void free_channel_list(struct ksmbd_session *sess)
{
struct channel *chann, *tmp;
+ write_lock(&sess->chann_lock);
list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
chann_list) {
list_del(&chann->chann_list);
kfree(chann);
}
+ write_unlock(&sess->chann_lock);
}
static void __session_rpc_close(struct ksmbd_session *sess,
@@ -149,11 +151,6 @@ void ksmbd_session_destroy(struct ksmbd_session *sess)
if (!sess)
return;
- if (!atomic_dec_and_test(&sess->refcnt))
- return;
-
- list_del(&sess->sessions_entry);
-
down_write(&sessions_table_lock);
hash_del(&sess->hlist);
up_write(&sessions_table_lock);
@@ -181,53 +178,70 @@ static struct ksmbd_session *__session_lookup(unsigned long long id)
return NULL;
}
-void ksmbd_session_register(struct ksmbd_conn *conn,
- struct ksmbd_session *sess)
+int ksmbd_session_register(struct ksmbd_conn *conn,
+ struct ksmbd_session *sess)
{
- sess->conn = conn;
- list_add(&sess->sessions_entry, &conn->sessions);
+ sess->dialect = conn->dialect;
+ memcpy(sess->ClientGUID, conn->ClientGUID, SMB2_CLIENT_GUID_SIZE);
+ return xa_err(xa_store(&conn->sessions, sess->id, sess, GFP_KERNEL));
}
-void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
+static int ksmbd_chann_del(struct ksmbd_conn *conn, struct ksmbd_session *sess)
{
- struct ksmbd_session *sess;
-
- while (!list_empty(&conn->sessions)) {
- sess = list_entry(conn->sessions.next,
- struct ksmbd_session,
- sessions_entry);
+ struct channel *chann, *tmp;
- ksmbd_session_destroy(sess);
+ write_lock(&sess->chann_lock);
+ list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
+ chann_list) {
+ if (chann->conn == conn) {
+ list_del(&chann->chann_list);
+ kfree(chann);
+ write_unlock(&sess->chann_lock);
+ return 0;
+ }
}
-}
+ write_unlock(&sess->chann_lock);
-static bool ksmbd_session_id_match(struct ksmbd_session *sess,
- unsigned long long id)
-{
- return sess->id == id;
+ return -ENOENT;
}
-struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn,
- unsigned long long id)
+void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
{
- struct ksmbd_session *sess = NULL;
+ struct ksmbd_session *sess;
- list_for_each_entry(sess, &conn->sessions, sessions_entry) {
- if (ksmbd_session_id_match(sess, id))
- return sess;
+ if (conn->binding) {
+ int bkt;
+
+ down_write(&sessions_table_lock);
+ hash_for_each(sessions_table, bkt, sess, hlist) {
+ if (!ksmbd_chann_del(conn, sess)) {
+ up_write(&sessions_table_lock);
+ goto sess_destroy;
+ }
+ }
+ up_write(&sessions_table_lock);
+ } else {
+ unsigned long id;
+
+ xa_for_each(&conn->sessions, id, sess) {
+ if (!ksmbd_chann_del(conn, sess))
+ goto sess_destroy;
+ }
}
- return NULL;
-}
-int get_session(struct ksmbd_session *sess)
-{
- return atomic_inc_not_zero(&sess->refcnt);
+ return;
+
+sess_destroy:
+ if (list_empty(&sess->ksmbd_chann_list)) {
+ xa_erase(&conn->sessions, sess->id);
+ ksmbd_session_destroy(sess);
+ }
}
-void put_session(struct ksmbd_session *sess)
+struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn,
+ unsigned long long id)
{
- if (atomic_dec_and_test(&sess->refcnt))
- pr_err("get/%s seems to be mismatched.", __func__);
+ return xa_load(&conn->sessions, id);
}
struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id)
@@ -236,10 +250,6 @@ struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id)
down_read(&sessions_table_lock);
sess = __session_lookup(id);
- if (sess) {
- if (!get_session(sess))
- sess = NULL;
- }
up_read(&sessions_table_lock);
return sess;
@@ -253,6 +263,8 @@ struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn,
sess = ksmbd_session_lookup(conn, id);
if (!sess && conn->binding)
sess = ksmbd_session_lookup_slowpath(id);
+ if (sess && sess->state != SMB2_SESSION_VALID)
+ sess = NULL;
return sess;
}
@@ -314,12 +326,11 @@ static struct ksmbd_session *__session_create(int protocol)
goto error;
set_session_flag(sess, protocol);
- INIT_LIST_HEAD(&sess->sessions_entry);
xa_init(&sess->tree_conns);
INIT_LIST_HEAD(&sess->ksmbd_chann_list);
INIT_LIST_HEAD(&sess->rpc_handle_list);
sess->sequence_number = 1;
- atomic_set(&sess->refcnt, 1);
+ rwlock_init(&sess->chann_lock);
switch (protocol) {
case CIFDS_SESSION_FLAG_SMB2: