pub mod frame;
mod message;
pub use self::message::Message;
pub use self::frame::CloseFrame;
use std::collections::VecDeque;
use std::io::{Read, Write, ErrorKind as IoErrorKind};
use std::mem::replace;
use error::{Error, Result};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::frame::{Frame, FrameSocket};
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode};
use util::NonBlockingResult;
/// Which end of the connection this `WebSocket` represents.
///
/// The role selects the masking rules applied to frames (clients mask what
/// they send and reject masked incoming frames; servers reject unmasked
/// incoming frames). It also selects how the closing handshake ends: a server
/// reports `ConnectionClosed` as soon as the handshake completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// This socket is the server side of the connection.
    Server,
    /// This socket is the client side of the connection.
    Client,
}
/// Limits applied to a WebSocket connection.
///
/// Every field is optional. A `None` send-queue limit disables that check.
/// The other two values are handed to the frame/message layer as-is
/// (presumably `None` means "no limit" there as well).
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
    /// Maximum number of frames waiting in the outgoing queue. When the queue
    /// is still this full after a flush attempt, `write_message` returns
    /// `Error::SendQueueFull`. Default: unlimited.
    pub max_send_queue: Option<usize>,
    /// Maximum size of a reassembled incoming message, in payload bytes.
    /// Default: 64 MiB.
    pub max_message_size: Option<usize>,
    /// Maximum size of a single incoming frame (presumably in bytes).
    /// Default: 16 MiB.
    pub max_frame_size: Option<usize>,
}

impl Default for WebSocketConfig {
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        WebSocketConfig {
            max_send_queue: None,
            max_message_size: Some(64 * MIB),
            max_frame_size: Some(16 * MIB),
        }
    }
}
/// A WebSocket connection running over `Stream`.
///
/// It handles frame masking, fragment reassembly, automatic pong replies and
/// the closing handshake. The constructors take an already-connected stream;
/// the code in this module performs no opening handshake.
#[derive(Debug)]
pub struct WebSocket<Stream> {
/// Server or client; decides masking rules and close behaviour.
role: Role,
/// Frame-level reader/writer wrapping the raw stream.
socket: FrameSocket<Stream>,
/// Progress of the closing handshake.
state: WebSocketState,
/// Fragmented data message currently being reassembled, if any.
incomplete: Option<IncompleteMessage>,
/// Outgoing frames not yet handed to the frame socket.
send_queue: VecDeque<Frame>,
/// Pending pong reply; only the most recent one is kept.
pong: Option<Frame>,
/// Limits for this connection.
config: WebSocketConfig,
}
impl<Stream> WebSocket<Stream> {
    /// Wrap an already-connected stream in a `WebSocket`.
    ///
    /// Nothing is read or written here; the stream is only wrapped in a frame
    /// socket. Passing `None` for `config` selects `WebSocketConfig::default()`.
    pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
        let frame_socket = FrameSocket::new(stream);
        Self::from_frame_socket(frame_socket, role, config)
    }

    /// Like `from_raw_socket`, but `part` holds bytes already read from the
    /// stream (for example, leftovers after parsing an HTTP upgrade). They are
    /// handed to the frame socket as the start of its input.
    pub fn from_partially_read(
        stream: Stream,
        part: Vec<u8>,
        role: Role,
        config: Option<WebSocketConfig>,
    ) -> Self {
        let frame_socket = FrameSocket::from_partially_read(stream, part);
        Self::from_frame_socket(frame_socket, role, config)
    }

    /// Shared access to the underlying stream.
    pub fn get_ref(&self) -> &Stream {
        self.socket.get_ref()
    }

    /// Exclusive access to the underlying stream.
    pub fn get_mut(&mut self) -> &mut Stream {
        self.socket.get_mut()
    }

    /// Adjust the connection limits in place via a callback.
    pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
        let config = &mut self.config;
        set_func(config)
    }
}
impl<Stream> WebSocket<Stream> {
    /// Build a connection in the `Active` state around an existing frame socket.
    ///
    /// The queues start empty and no fragment or pong is pending. A missing
    /// `config` falls back to `WebSocketConfig::default()`.
    fn from_frame_socket(
        socket: FrameSocket<Stream>,
        role: Role,
        config: Option<WebSocketConfig>
    ) -> Self {
        let config = config.unwrap_or_default();
        let send_queue = VecDeque::new();
        WebSocket {
            role,
            socket,
            state: WebSocketState::Active,
            incomplete: None,
            send_queue,
            pong: None,
            config,
        }
    }
}
impl<Stream: Read + Write> WebSocket<Stream> {
pub fn read_message(&mut self) -> Result<Message> {
loop {
self.write_pending().no_block()?;
let res = self.read_message_frame();
if let Some(message) = self.translate_close(res)? {
trace!("Received message {}", message);
return Ok(message)
}
}
}
pub fn write_message(&mut self, message: Message) -> Result<()> {
if let Some(max_send_queue) = self.config.max_send_queue {
if self.send_queue.len() >= max_send_queue {
self.write_pending().no_block()?;
}
if self.send_queue.len() >= max_send_queue {
return Err(Error::SendQueueFull(message));
}
}
let frame = match message {
Message::Text(data) => {
Frame::message(data.into(), OpCode::Data(OpData::Text), true)
}
Message::Binary(data) => {
Frame::message(data, OpCode::Data(OpData::Binary), true)
}
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
return self.write_pending()
}
};
self.send_queue.push_back(frame);
self.write_pending()
}
pub fn close(&mut self, code: Option<CloseFrame>) -> Result<()> {
if let WebSocketState::Active = self.state {
self.state = WebSocketState::ClosedByUs;
let frame = Frame::close(code);
self.send_queue.push_back(frame);
} else {
}
self.write_pending()
}
pub fn write_pending(&mut self) -> Result<()> {
{
let res = self.socket.write_pending();
self.translate_close(res)?;
}
if let Some(pong) = self.pong.take() {
self.send_one_frame(pong)?;
}
while let Some(data) = self.send_queue.pop_front() {
self.send_one_frame(data)?;
}
if let WebSocketState::ClosedByPeer(ref mut frame) = self.state {
match self.role {
Role::Client => Ok(()),
Role::Server => Err(Error::ConnectionClosed(frame.take())),
}
} else {
Ok(())
}
}
}
// Internal frame-level processing: decoding incoming frames, driving the
// closing handshake, and writing single frames.
impl<Stream: Read + Write> WebSocket<Stream> {
/// Read and process one frame from the frame socket.
///
/// Returns `Ok(Some(message))` when the frame yields a message for the caller.
/// That is a completed text/binary message, or a ping/pong while the
/// connection is active. Returns `Ok(None)` when the frame was consumed
/// internally: a close frame, a non-final fragment, or a frame ignored because
/// the closing handshake has started.
///
/// When `read_frame` yields no frame, the state becomes `Terminated`.
/// The result is then `ConnectionClosed` if the peer's close frame had been
/// seen (`ClosedByPeer` / `CloseAcknowledged`), and `Protocol` otherwise.
fn read_message_frame(&mut self) -> Result<Option<Message>> {
if let Some(mut frame) = self.socket.read_frame(self.config.max_frame_size)? {
{
let hdr = frame.header();
// RSV bits must be zero unless an extension defines them (RFC 6455
// §5.2). This code accepts no extensions, so any set bit is an error.
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()))
}
}
// Masking rules (RFC 6455 §5.3): client-to-server frames must be masked,
// and server-to-client frames must not be. A server removes the mask via
// `apply_mask` before the payload is used.
match self.role {
Role::Server => {
if frame.is_masked() {
frame.apply_mask()
} else {
return Err(Error::Protocol("Received an unmasked frame from client".into()))
}
}
Role::Client => {
if frame.is_masked() {
return Err(Error::Protocol("Received a masked frame from server".into()))
}
}
}
match frame.header().opcode {
OpCode::Control(ctl) => {
// Arms are tried in order. The fragmentation and 125-byte payload
// limits (RFC 6455 §5.5) apply to every control frame before its
// type is looked at.
match ctl {
_ if !frame.header().is_final => {
Err(Error::Protocol("Fragmented control frame".into()))
}
_ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => {
// Hand the decoded close frame to the handshake state machine;
// nothing is surfaced to the caller.
self.do_close(frame.into_close()?).map(|_| None)
}
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
// Once a close frame was sent or received, pings/pongs are dropped.
Ok(None)
}
OpCtl::Ping => {
let data = frame.into_data();
// Schedule a pong echoing the payload. It replaces any pong not yet
// sent, so only the latest ping gets answered.
self.pong = Some(Frame::pong(data.clone()));
Ok(Some(Message::Ping(data)))
}
OpCtl::Pong => {
Ok(Some(Message::Pong(frame.into_data())))
}
}
}
OpCode::Data(_) if !self.state.is_active() => {
// Data arriving after the closing handshake started is discarded.
Ok(None)
}
OpCode::Data(data) => {
let fin = frame.header().is_final;
match data {
OpData::Continue => {
// Append to the message being reassembled. The size limit is
// checked on every fragment.
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol("Continue frame but nothing to continue".into()))
}
if fin {
// `incomplete` was confirmed to be `Some` just above.
Ok(Some(self.incomplete.take().unwrap().complete()?))
} else {
Ok(None)
}
}
// Any non-continuation data frame (including reserved opcodes) is
// rejected while a fragmented message is still open. This arm must
// stay before the Text/Binary and Reserved arms.
c if self.incomplete.is_some() => {
Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into()
))
}
OpData::Text | OpData::Binary => {
// Start a new message; the size limit also applies to this first fragment.
let msg = {
let message_type = match data {
OpData::Text => IncompleteMessageType::Text,
OpData::Binary => IncompleteMessageType::Binary,
// The enclosing arm only matches Text or Binary.
_ => panic!("Bug: message is not text nor binary"),
};
let mut m = IncompleteMessage::new(message_type);
m.extend(frame.into_data(), self.config.max_message_size)?;
m
};
if fin {
Ok(Some(msg.complete()?))
} else {
// Keep it until the final continuation frame arrives.
self.incomplete = Some(msg);
Ok(None)
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
}
}
}
}
} else {
// No frame from the socket (presumably end of stream; see
// `FrameSocket::read_frame`). The connection is finished either way; it
// is a clean close only if the peer's close frame was already received.
match replace(&mut self.state, WebSocketState::Terminated) {
WebSocketState::CloseAcknowledged(close) | WebSocketState::ClosedByPeer(close) => {
Err(Error::ConnectionClosed(close))
}
_ => {
Err(Error::Protocol("Connection reset without closing handshake".into()))
}
}
}
}
/// Advance the closing handshake after receiving a close frame.
///
/// * `Active`: the peer initiated the close. Its frame is stored in
///   `ClosedByPeer` and a reply is queued. The reply uses `Normal` if the
///   peer's code is allowed, `Protocol` if it is not, and has no code if the
///   peer sent none.
/// * `ClosedByPeer` / `CloseAcknowledged`: a repeated close frame; ignored.
/// * `ClosedByUs`: the peer acknowledged our close. A client moves to
///   `CloseAcknowledged`; a server returns `ConnectionClosed` immediately.
fn do_close(&mut self, close: Option<CloseFrame>) -> Result<()> {
debug!("Received close frame: {:?}", close);
match self.state {
WebSocketState::Active => {
// Copy the code out before `close` is moved into the new state.
let close_code = close.as_ref().map(|f| f.code);
self.state = WebSocketState::ClosedByPeer(close.map(CloseFrame::into_owned));
let reply = if let Some(code) = close_code {
if code.is_allowed() {
Frame::close(Some(CloseFrame {
code: CloseCode::Normal,
reason: "".into(),
}))
} else {
Frame::close(Some(CloseFrame {
code: CloseCode::Protocol,
reason: "Protocol violation".into()
}))
}
} else {
Frame::close(None)
};
debug!("Replying to close with {:?}", reply);
// Sent on the next `write_pending`.
self.send_queue.push_back(reply);
Ok(())
}
WebSocketState::ClosedByPeer(_) | WebSocketState::CloseAcknowledged(_) => {
Ok(())
}
WebSocketState::ClosedByUs => {
let close = close.map(CloseFrame::into_owned);
match self.role {
Role::Client => {
// The end of the stream is then reported by `read_message_frame`.
self.state = WebSocketState::CloseAcknowledged(close);
Ok(())
}
Role::Server => {
Err(Error::ConnectionClosed(close))
}
}
}
// NOTE(review): `Terminated` is set once `read_frame` yields no frame. If
// the stream still produced a close frame after that, this would panic
// rather than return an error. Confirm that `FrameSocket` cannot do that.
WebSocketState::Terminated => unreachable!(),
}
}
/// Write one frame to the socket. As a client it is masked with a random key
/// first (RFC 6455 §5.3); server frames are sent unmasked.
fn send_one_frame(&mut self, mut frame: Frame) -> Result<()> {
match self.role {
Role::Server => {
}
Role::Client => {
frame.set_random_mask();
}
}
let res = self.socket.write_frame(frame);
self.translate_close(res)
}
/// Report a `ConnectionReset` I/O error as `ConnectionClosed` once the peer's
/// close frame has been received (`ClosedByPeer` or `CloseAcknowledged`).
///
/// The stored close frame is moved out with `take`, so it is reported at most
/// once. All other results, including other I/O errors, pass through unchanged.
fn translate_close<T>(&mut self, res: Result<T>) -> Result<T> {
match res {
Err(Error::Io(err)) => Err({
if err.kind() == IoErrorKind::ConnectionReset {
match self.state {
WebSocketState::ClosedByPeer(ref mut frame) =>
Error::ConnectionClosed(frame.take()),
WebSocketState::CloseAcknowledged(ref mut frame) =>
Error::ConnectionClosed(frame.take()),
_ => Error::Io(err),
}
} else {
Error::Io(err)
}
}),
x => x,
}
}
}
/// Progress of the closing handshake for a `WebSocket`.
#[derive(Debug)]
enum WebSocketState {
/// Open connection: no close frame sent or received yet.
Active,
/// We queued a close frame (via `close`) and await the peer's reply.
ClosedByUs,
/// The peer sent a close frame first; holds that frame (if it had a payload).
/// A reply close frame has been queued.
ClosedByPeer(Option<CloseFrame<'static>>),
/// We closed first and the peer answered with this close frame. Only the
/// client role enters this state in `do_close`.
CloseAcknowledged(Option<CloseFrame<'static>>),
/// Set by `read_message_frame` once the frame socket yields no more frames.
Terminated,
}
impl WebSocketState {
    /// `true` only while no close frame has been sent or received.
    fn is_active(&self) -> bool {
        // Every variant is listed so that adding a state forces a decision here.
        match *self {
            WebSocketState::Active => true,
            WebSocketState::ClosedByUs
            | WebSocketState::ClosedByPeer(_)
            | WebSocketState::CloseAcknowledged(_)
            | WebSocketState::Terminated => false,
        }
    }
}
// Unit tests feeding hand-encoded, unmasked server frames to a client socket.
#[cfg(test)]
mod tests {
use super::{WebSocket, Role, Message, WebSocketConfig};
use std::io;
use std::io::Cursor;
/// Test stream: reads come from the wrapped stream; writes are discarded but
/// reported as fully written.
struct WriteMoc<Stream>(Stream);
impl<Stream> io::Write for WriteMoc<Stream> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<Stream: io::Read> io::Read for WriteMoc<Stream> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
#[test]
fn receive_messages() {
let incoming = Cursor::new(vec![
// FIN + Ping (0x9), payload length 2: [1, 2].
0x89, 0x02, 0x01, 0x02,
// FIN + Pong (0xA), payload length 1: [3].
0x8a, 0x01, 0x03,
// Non-final Text (0x1), length 7: "Hello, ".
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
// FIN + Continuation (0x0), length 6: "World!".
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
// FIN + Binary (0x2), length 3: [1, 2, 3].
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
}
#[test]
fn size_limiting_text_fragmented() {
// Two text fragments of 7 and 6 bytes; the 10-byte limit trips on the second.
let incoming = Cursor::new(vec![
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_message_size: Some(10),
.. WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 7 + 6 > 10"
);
}
#[test]
fn size_limiting_binary() {
// A single 3-byte binary frame against a 2-byte limit.
let incoming = Cursor::new(vec![
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let limit = WebSocketConfig {
max_message_size: Some(2),
.. WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2"
);
}
}