diff options
Diffstat (limited to 'fs/ksmbd/smb_common.c')
-rw-r--r-- | fs/ksmbd/smb_common.c | 60 |
1 files changed, 37 insertions, 23 deletions
diff --git a/fs/ksmbd/smb_common.c b/fs/ksmbd/smb_common.c index 43d3123d8b62..db8042a173d0 100644 --- a/fs/ksmbd/smb_common.c +++ b/fs/ksmbd/smb_common.c @@ -129,16 +129,22 @@ int ksmbd_lookup_protocol_idx(char *str) * * check for valid smb signature and packet direction(request/response) * - * Return: 0 on success, otherwise 1 + * Return: 0 on success, otherwise -EINVAL */ int ksmbd_verify_smb_message(struct ksmbd_work *work) { - struct smb2_hdr *smb2_hdr = work->request_buf; + struct smb2_hdr *smb2_hdr = work->request_buf + work->next_smb2_rcv_hdr_off; + struct smb_hdr *hdr; if (smb2_hdr->ProtocolId == SMB2_PROTO_NUMBER) return ksmbd_smb2_check_message(work); - return 0; + hdr = work->request_buf; + if (*(__le32 *)hdr->Protocol == SMB1_PROTO_NUMBER && + hdr->Command == SMB_COM_NEGOTIATE) + return 0; + + return -EINVAL; } /** @@ -149,20 +155,7 @@ int ksmbd_verify_smb_message(struct ksmbd_work *work) */ bool ksmbd_smb_request(struct ksmbd_conn *conn) { - int type = *(char *)conn->request_buf; - - switch (type) { - case RFC1002_SESSION_MESSAGE: - /* Regular SMB request */ - return true; - case RFC1002_SESSION_KEEP_ALIVE: - ksmbd_debug(SMB, "RFC 1002 session keep alive\n"); - break; - default: - ksmbd_debug(SMB, "RFC 1002 unknown request type 0x%x\n", type); - } - - return false; + return conn->request_buf[0] == 0; } static bool supported_protocol(int idx) @@ -176,10 +169,12 @@ static bool supported_protocol(int idx) idx <= server_conf.max_protocol); } -static char *next_dialect(char *dialect, int *next_off) +static char *next_dialect(char *dialect, int *next_off, int bcount) { dialect = dialect + *next_off; - *next_off = strlen(dialect); + *next_off = strnlen(dialect, bcount); + if (dialect[*next_off] != '\0') + return NULL; return dialect; } @@ -194,7 +189,9 @@ static int ksmbd_lookup_dialect_by_name(char *cli_dialects, __le16 byte_count) dialect = cli_dialects; bcount = le16_to_cpu(byte_count); do { - dialect = next_dialect(dialect, &next); + dialect = next_dialect(dialect, &next, bcount); + if (!dialect) + break; ksmbd_debug(SMB, "client requested dialect %s\n", dialect); if (!strcmp(dialect, smb1_protos[i].name)) { @@ -242,13 +239,22 @@ int ksmbd_lookup_dialect_by_id(__le16 *cli_dialects, __le16 dialects_count) static int ksmbd_negotiate_smb_dialect(void *buf) { - __le32 proto; + int smb_buf_length = get_rfc1002_len(buf); + __le32 proto = ((struct smb2_hdr *)buf)->ProtocolId; - proto = ((struct smb2_hdr *)buf)->ProtocolId; if (proto == SMB2_PROTO_NUMBER) { struct smb2_negotiate_req *req; + int smb2_neg_size = + offsetof(struct smb2_negotiate_req, Dialects) - 4; req = (struct smb2_negotiate_req *)buf; + if (smb2_neg_size > smb_buf_length) + goto err_out; + + if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) > + smb_buf_length) + goto err_out; + return ksmbd_lookup_dialect_by_id(req->Dialects, req->DialectCount); } @@ -258,14 +264,22 @@ static int ksmbd_negotiate_smb_dialect(void *buf) struct smb_negotiate_req *req; req = (struct smb_negotiate_req *)buf; + if (le16_to_cpu(req->ByteCount) < 2) + goto err_out; + + if (offsetof(struct smb_negotiate_req, DialectsArray) - 4 + + le16_to_cpu(req->ByteCount) > smb_buf_length) { + goto err_out; + } + return ksmbd_lookup_dialect_by_name(req->DialectsArray, req->ByteCount); } +err_out: return BAD_PROT_ID; } -#define SMB_COM_NEGOTIATE 0x72 int ksmbd_init_smb_server(struct ksmbd_work *work) { struct ksmbd_conn *conn = work->conn; |